Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
Mark Leone authored and Mark Leone committed Jan 29, 2025
2 parents a66c98e + 36db25d commit 9fdb297
Show file tree
Hide file tree
Showing 20 changed files with 426 additions and 384 deletions.
Binary file added stack/main/data/models/ssmr/ssmr_200g.npz
Binary file not shown.
Binary file removed stack/main/data/models/ssmr/ssmr_200g.pkl
Binary file not shown.
Binary file added stack/main/data/models/ssmr/ssmr_300g.npz
Binary file not shown.
Binary file removed stack/main/data/models/ssmr/ssmr_300g.pkl
Binary file not shown.
Binary file added stack/main/data/models/ssmr/ssmr_400g.npz
Binary file not shown.
Binary file removed stack/main/data/models/ssmr/ssmr_400g.pkl
Binary file not shown.
19 changes: 17 additions & 2 deletions stack/main/scripts/data_visualization.ipynb

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions stack/main/src/controller/controller/mpc/locp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import cvxpy as cp
import numpy as np
from scipy.linalg import block_diag
from cvxpy.atoms.affine.reshape import reshape
from functools import partial
from cvxpy.atoms.affine.reshape import reshape as cvxpy_reshape
reshape = partial(cvxpy_reshape, order='F') # future proof
import time
import scipy.sparse as sp
import jax
Expand Down Expand Up @@ -290,7 +292,7 @@ def _set_constraints(self):
# Trust region constraints
if self.tr_active:
X_scale = self.x_scale.reshape(-1, 1).repeat(self.N + 1, axis=1)
dx = cp.reshape(self.x, (self.n_x, self.N + 1)) - self.xk.T
dx = reshape(self.x, (self.n_x, self.N + 1)) - self.xk.T
dx_scaled = cp.multiply(X_scale, dx)
constr += [cp.norm(dx_scaled, 'inf', axis=0) <= self.delta + self.st]

Expand Down
72 changes: 6 additions & 66 deletions stack/main/src/controller/controller/mpc_solver_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import jax.numpy as jnp
import rclpy # type: ignore
from rclpy.node import Node # type: ignore
from rclpy.qos import QoSProfile # type: ignore
from scipy.interpolate import interp1d
from interfaces.srv import ControlSolver
from .mpc.gusto import GuSTO
Expand Down Expand Up @@ -32,7 +31,7 @@ def run_mpc_solver_node(model, config, x0, t=None, dt=None, z=None, u=None, zf=N
(https://osqp.org/docs/interfaces/solver_settings.html)
"""
assert t is not None or dt is not None, "Either t array or dt must be provided."
rclpy.init()
# rclpy.init()
node = MPCSolverNode(model, config, x0, t=t, dt=dt, z=z, u=u, zf=zf,
U=U, X=X, Xf=Xf, dU=dU, **kwargs)
rclpy.spin(node)
Expand Down Expand Up @@ -66,6 +65,7 @@ def __init__(self, model, config, x0, t=None, dt=None, z=None, u=None, zf=None,
self.model = model
if dt is None and t is not None:
self.dt = t[1] - t[0]
self.N = config.N

# Define target values
self.z = z
Expand All @@ -80,7 +80,7 @@ def __init__(self, model, config, x0, t=None, dt=None, z=None, u=None, zf=None,

# Set up GuSTO and run first solve with a simple initial guess
u_init = jnp.zeros((config.N, self.model.n_u))
x_init, _ = self.model.rollout(x0, u_init, self.dt)
x_init = self.model.rollout(x0, u_init, self.dt)
z, zf, u = self.get_target(0.0)
self.gusto = GuSTO(model, config, x0, u_init, x_init, z=z, u=u,
zf=zf, U=U, X=X, Xf=Xf, dU=dU, **kwargs)
Expand All @@ -92,6 +92,7 @@ def __init__(self, model, config, x0, t=None, dt=None, z=None, u=None, zf=None,

# Define the service, which uses the gusto callback function
self.srv = self.create_service(ControlSolver, 'mpc_solver', self.gusto_callback)
self.get_logger().info('MPC solver service has been created.')

def gusto_callback(self, request, response):
"""
Expand Down Expand Up @@ -143,8 +144,8 @@ def get_target(self, t0):
else:
z = None

# Get target zf term for cost function
if self.Qzf is not None and z is not None:
# Get target zf term for cost function
if z is not None:
zf = z[-1, :]
else:
zf = None
Expand All @@ -159,64 +160,3 @@ def get_target(self, t0):
u = None

return z, zf, u


class MPCClientNode(Node):
"""
The client side of the MPC service. This object is used to query
the ROS node to solve a GuSTO problem.
Once a MPCSolverNode is running, instantiate this object and then use
send_request to send a query the GuSTO solver.
"""

def __init__(self):
rclpy.init()
super().__init__('mpc_client')
self.cli = self.create_client(ControlSolver, 'mpc_solver')

# Wait until the solver node is up and running
while not self.cli.wait_for_service(timeout_sec=1.0):
self.get_logger().info('MPC solver not available, waiting...')

# Request message definition
self.req = ControlSolver.Request()

def send_request(self, t0, x0, wait=False):
"""
:param t0:
:param x0:
:param wait: Boolean
:return:
"""
self.req.t0 = t0
self.req.x0 = jnp2arr(x0)

self.future = self.cli.call_async(self.req)

if wait:
# Synchronous call, not compatible for real-time applications
rclpy.spin_until_future_complete(self, self.future)

def force_spin(self):
if not self.check_if_done():
rclpy.spin_once(self, timeout_sec=0)

def check_if_done(self):
return self.future.done()

def force_wait(self):
self.get_logger().warning('Overrides realtime compatibility, solve is too slow. Consider modifying problem')
rclpy.spin_until_future_complete(self, self.future)

def get_solution(self, n_x, n_u):
"""
Obtain result from MPC solver.
"""
res = self.future.result()
t = arr2jnp(res.t, 1, squeeze=True)
xopt = arr2jnp(res.xopt, n_x)
uopt = arr2jnp(res.uopt, n_u)
t_solve = res.solve_time

return t, uopt, xopt, t_solve
164 changes: 164 additions & 0 deletions stack/main/src/executor/executor/experiment_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import os
import jax
import jax.numpy as jnp
import rclpy # type: ignore
from rclpy.node import Node # type: ignore
from rclpy.qos import QoSProfile # type: ignore
from controller.mpc_solver_node import arr2jnp, jnp2arr # type: ignore
from interfaces.msg import SingleMotorControl, AllMotorsControl, TrunkRigidBodies
from interfaces.srv import ControlSolver


class RunExperimentNode(Node):
"""
This node is responsible for running the main experiment.
"""
def __init__(self):
super().__init__('run_experiment_node')
self.declare_parameters(namespace='', parameters=[
('debug', False), # False or True (print debug messages)
('controller_type', 'mpc'), # 'ik' or 'mpc' (what controller to use)
('results_name', 'test_experiment') # name of the results file
])

self.debug = self.get_parameter('debug').value
self.controller_type = self.get_parameter('controller_type').value
self.results_name = self.get_parameter('results_name').value
self.data_dir = os.getenv('TRUNK_DATA', '/home/trunk/Documents/trunk-stack/stack/main/data')

# Settled positions of the rigid bodies
self.rest_position = jnp.array([0.1005, -0.10698, 0.10445, -0.10302, -0.20407, 0.10933, 0.10581, -0.32308, 0.10566])

if self.controller_type == 'mpc':
# Subscribe to current positions
self.mocap_subscription = self.create_subscription(
TrunkRigidBodies,
'/trunk_rigid_bodies',
self.mocap_listener_callback,
QoSProfile(depth=10)
)

# Create MPC solver service client
self.mpc_client = self.create_client(
ControlSolver,
'mpc_solver'
)
self.get_logger().info('MPC client created.')
while not self.mpc_client.wait_for_service(timeout_sec=1.0):
self.get_logger().info('MPC solver not available, waiting...')
# Request message definition
self.req = ControlSolver.Request()

elif self.controller_type == 'ik':
# Create control solver service client
self.ik_client = self.create_client(
ControlSolver,
'ik_solver'
)
self.get_logger().info('IK client created.')
while not self.ik_client.wait_for_service(timeout_sec=1.0):
self.get_logger().info('IK solver not available, waiting...')
else:
raise ValueError('Invalid controller type: ' + self.controller_type + '. Valid options are: "ik" or "mpc".')

# Create publisher to execute found control inputs
# NOTE: still disabled later in code for now
self.controls_publisher = self.create_publisher(
AllMotorsControl,
'/all_motors_control',
QoSProfile(depth=10)
)

# Maintain current observations because of the delay embedding
# self.y = jnp.zeros(self.model.n_y)

self.get_logger().info('Run experiment node has been started.')

self.clock = self.get_clock()
# self.start_time = self.clock.now().nanoseconds / 1e9

def mocap_listener_callback(self, msg):
if self.debug:
self.get_logger().info(f'Received mocap data: {msg.positions}.')

# Unpack the message into simple list of positions, eg [x1, y1, z1, x2, y2, z2, ...]
y_new = jnp.array([coord for pos in msg.positions for coord in [pos.x, pos.y, pos.z]])

# Center the data around settled positions
y_centered = y_new - self.settled_position

# Subselect bottom two segments
y_centered_midtip = y_centered[3:]

# Update the current observations, including *single* delay embedding
self.y = jnp.concatenate([y_centered_midtip, self.y[:6]])

t0 = self.clock.now().nanoseconds / 1e9
x0 = self.model.encode(self.y)

# Call the service
self.mpc_client.send_request(t0, x0, wait=False)
self.mpc_client.future.add_done_callback(self.service_callback)

def service_callback(self, async_response):
try:
response = async_response.result()
# TODO: enable control execution
# self.publish_control_inputs(response.uopt)
except Exception as e:
self.get_logger().error(f'Service call failed: {e}.')

def publish_control_inputs(self, control_inputs):
control_message = AllMotorsControl()
control_message.motors_control = [
SingleMotorControl(mode=0, value=value) for value in control_inputs
]
self.controls_publisher.publish(control_message)
if self.debug:
self.get_logger().info(f'Published new motor control setting: {control_inputs}.')

def send_request(self, t0, x0, wait=False):
"""
Send request to MPC solver.
"""
self.req.t0 = t0
self.req.x0 = jnp2arr(x0)
self.future = self.mpc_client.call_async(self.req)

if wait:
# Synchronous call, not compatible for real-time applications
rclpy.spin_until_future_complete(self, self.future)

def get_solution(self, n_x, n_u):
"""
Obtain result from MPC solver.
"""
res = self.future.result()
t = arr2jnp(res.t, 1, squeeze=True)
xopt = arr2jnp(res.xopt, n_x)
uopt = arr2jnp(res.uopt, n_u)
t_solve = res.solve_time

return t, uopt, xopt, t_solve

def force_spin(self):
if not self.check_if_done():
rclpy.spin_once(self, timeout_sec=0)

def check_if_done(self):
return self.future.done()

def force_wait(self):
self.get_logger().warning('Overrides realtime compatibility, solve is too slow. Consider modifying problem')
rclpy.spin_until_future_complete(self, self.future)


def main(args=None):
rclpy.init(args=args)
run_experiment_node = RunExperimentNode()
rclpy.spin(run_experiment_node)
run_experiment_node.destroy_node()
rclpy.shutdown()

if __name__ == '__main__':
main()
86 changes: 86 additions & 0 deletions stack/main/src/executor/executor/mpc_initializer_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import os
import jax
import jax.numpy as jnp
import rclpy # type: ignore
from rclpy.node import Node # type: ignore
from controller.mpc.gusto import GuSTOConfig # type: ignore
from controller.mpc_solver_node import run_mpc_solver_node # type: ignore
from .utils.models import SSMR


class MPCInitializerNode(Node):
"""
This node initializes all that is needed for MPC.
"""
def __init__(self):
super().__init__('mpc_initializer_node')
self.declare_parameters(namespace='', parameters=[
('debug', False), # False or True (print debug messages)
('model_name', 'ssmr_200g'), # 'ssmr_200g' (what model to use)
])
self.debug = self.get_parameter('debug').value
self.model_name = self.get_parameter('model_name').value
self.data_dir = os.getenv('TRUNK_DATA', '/home/trunk/Documents/trunk-stack/stack/main/data')

# Generate reference trajectory
z_ref, t = self._generate_ref_trajectory(4, 0.01, 'circle', 0.15)

# Load the model
self._load_model()

# MPC configuration
gusto_config = GuSTOConfig(
Qz=jnp.eye(self.model.n_z),
Qzf=10*jnp.eye(self.model.n_z),
R=0.0001*jnp.eye(self.model.n_u),
x_char=0.05*jnp.ones(self.model.n_x),
f_char=0.5*jnp.ones(self.model.n_x),
N=7
)
x0 = jnp.zeros(self.model.n_x)
self.mpc_solver_node = run_mpc_solver_node(self.model, gusto_config, x0, t=t, z=z_ref)

def _load_model(self):
"""
Load the learned (non-autonomous) dynamics model of the system.
"""
model_path = os.path.join(self.data_dir, f'models/ssmr/{self.model_name}.npz')

# Load the model
self.model = SSMR(model_path=model_path)
print('---- Model loaded. Dimensions:')
print(' n_x:', self.model.n_x)
print(' n_u:', self.model.n_u)
print(' n_z:', self.model.n_z)
print(' n_y:', self.model.n_y)

def _generate_ref_trajectory(self, T, dt, traj_type, size):
"""
Generate a 3D reference trajectory for the system to track.
"""
t = jnp.linspace(0, T, int(T/dt))
z_ref = jnp.zeros((len(t), 3))

# Note that y is up
if traj_type == 'circle':
z_ref = z_ref.at[:, 0].set(size * jnp.cos(2 * jnp.pi / T * t))
z_ref = z_ref.at[:, 1].set(size / 2 * jnp.ones_like(t))
z_ref = z_ref.at[:, 2].set(size * jnp.sin(2 * jnp.pi / T * t))
elif traj_type == 'figure_eight':
z_ref = z_ref.at[:, 0].set(size * jnp.sin(jnp.pi / T * t))
z_ref = z_ref.at[:, 1].set(size / 2 * jnp.ones_like(t))
z_ref = z_ref.at[:, 2].set(size * jnp.sin(2 * jnp.pi / T * t))
else:
raise ValueError('Invalid trajectory type: ' + traj_type + '. Valid options are: "circle" or "figure_eight".')
return z_ref, t


def main(args=None):
rclpy.init(args=args)
mpc_initializer_node = MPCInitializerNode()
rclpy.spin(mpc_initializer_node)
mpc_initializer_node.destroy_node()
rclpy.shutdown()

if __name__ == '__main__':
main()
Loading

0 comments on commit 9fdb297

Please sign in to comment.