Skip to content

Commit

Permalink
post SRC demo commit, static models working for 200g internal weight
Browse files Browse the repository at this point in the history
  • Loading branch information
hbuurmei committed Jan 14, 2025
1 parent b9a48b8 commit 585e801
Show file tree
Hide file tree
Showing 10 changed files with 17,643 additions and 41 deletions.
17,603 changes: 17,603 additions & 0 deletions stack/main/data/trajectories/dynamic/TEST_controlled_0g_ml.csv

Large diffs are not rendered by default.

Empty file.
4 changes: 2 additions & 2 deletions stack/main/src/controller/controller/mpc_solver_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from rclpy.qos import QoSProfile # type: ignore
from scipy.interpolate import interp1d
from interfaces.srv import ControlSolver
from controller.mpc.gusto import GuSTO
from .mpc.gusto import GuSTO


def run_mpc_solver_node(model, config, x0, t=None, dt=None, z=None, u=None, zf=None,
Expand Down Expand Up @@ -182,7 +182,7 @@ def __init__(self):
# Request message definition
self.req = ControlSolver.Request()

def send_request(self, t0, x0, wait=True):
def send_request(self, t0, x0, wait=False):
"""
:param t0:
:param x0:
Expand Down
65 changes: 32 additions & 33 deletions stack/main/src/executor/executor/run_experiment_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import rclpy # type: ignore
from rclpy.node import Node # type: ignore
from rclpy.qos import QoSProfile # type: ignore
from controller.controller.mpc_solver_node import run_mpc_solver_node, MPCClientNode
from controller.controller.mpc import GuSTOConfig
from controller.mpc_solver_node import run_mpc_solver_node, MPCClientNode # type: ignore
from controller.mpc.gusto import GuSTOConfig # type: ignore
from interfaces.msg import SingleMotorControl, AllMotorsControl, TrunkRigidBodies
from interfaces.srv import ControlSolver
from .models.ssm import DelaySSM
from .models.models import SSMR
from .models.residual import ResidualBr, NeuralBr
from . import utils
from .utils.ssm import DelaySSM
from .utils.models import SSMR
from .utils.residual import ResidualBr, NeuralBr


class RunExperimentNode(Node):
Expand All @@ -23,26 +24,27 @@ def __init__(self):
self.declare_parameters(namespace='', parameters=[
('debug', False), # False or True (print debug messages)
('experiment_type', 'traj'), # 'traj' or 'user' (what input is being tracked)
('model', 'poly'), # 'nn' or 'poly' (what model to use)
('model_name', 'ssmr_200g'), # 'nn' or 'poly' (what model to use)
('controller_type', 'mpc'), # 'ik' or 'mpc' (what controller to use)
('output_name', 'base_experiment') # name of the output file
('results_name', 'base_experiment') # name of the results file
])

self.debug = self.get_parameter('debug').value
self.experiment_type = self.get_parameter('experiment_type').value
self.model_type = self.get_parameter('model').value
self.model_name = self.get_parameter('model_name').value
self.controller_type = self.get_parameter('controller_type').value
self.output_name = self.get_parameter('output_name').value
self.results_name = self.get_parameter('results_name').value

self.data_dir = os.getenv('TRUNK_DATA', '/home/asl/Documents/asl_trunk_ws/data')
self.data_dir = os.getenv('TRUNK_DATA', '/home/trunk/Documents/trunk-stack/stack/main/data')

# Settled positions of the rigid bodies
self.settled_positions = jnp.array([0, -0.10665, 0, 0, -0.20432, 0, 0, -0.320682, 0])
self.rest_position = jnp.array([0.1005, -0.10698, 0.10445, -0.10302, -0.20407, 0.10933, 0.10581, -0.32308, 0.10566])
self.avp_offset = jnp.array([0, -0.10698, 0, 0, -0.20407, 0, 0, -0.32308, 0])

# Get desired states
if self.experiment_type == 'trajectory':
if self.experiment_type == 'traj':
# Generate reference trajectory
z_ref, t = self._generate_ref_trajectory(3, 0.01, 'circle', 0.1)
z_ref, t = self._generate_ref_trajectory(4, 0.01, 'circle', 0.15)
elif self.experiment_type == 'user':
# We track positions as defined by the user (through the AVP)
self.avp_subscription = self.create_subscription(
Expand All @@ -69,9 +71,9 @@ def __init__(self):
Qz=jnp.eye(self.model.n_z),
Qzf=10*jnp.eye(self.model.n_z),
R=0.0001*jnp.eye(self.model.n_u),
x_char=jnp.ones(self.model.n_x),
f_char=jnp.ones(self.model.n_x),
N=6
x_char=0.05*jnp.ones(self.model.n_x),
f_char=0.5*jnp.ones(self.model.n_x),
N=7
)
x0 = jnp.zeros(self.model.n_x)
self.mpc_solver_node = run_mpc_solver_node(self.model, gusto_config, x0, t=t, z=z_ref)
Expand All @@ -90,31 +92,27 @@ def __init__(self):
raise ValueError('Invalid controller type: ' + self.controller_type + '. Valid options are: "ik" or "mpc".')

# Create publisher to execute found control inputs
""" TODO: enable control execution
self.controls_publisher = self.create_publisher(
AllMotorsControl,
'/all_motors_control',
QoSProfile(depth=10)
)
self.clock = self.get_clock()
self.start_time = self.clock.now().nanoseconds / 1e9

"""
# Maintain current observations because of the delay embedding
self.y = jnp.zeros(6)
self.y = jnp.zeros(self.model.n_y)

self.get_logger().info('Run experiment node has been started.')

self.clock = self.get_clock()
# self.start_time = self.clock.now().nanoseconds / 1e9

def _load_model(self):
"""
Load the learned dynamics model of the system used for control.
"""
# Get location of model file
if self.model_type == 'nn':
model_file = os.path.join(self.data_dir, 'models/nn_ssmr.pkl')
elif self.model_type == 'poly':
model_file = os.path.join(self.data_dir, 'models/poly_ssmr.pkl')
else:
raise ValueError('Invalid model type: ' + self.model_type + '. Valid options are: "nn" or "poly".')

model_file = os.path.join(self.data_dir, f'models/ssmr/{self.model_name}.pkl')

# Load the model
with open(model_file, 'rb') as f:
self.model = dill.load(f)
Expand Down Expand Up @@ -152,13 +150,13 @@ def mocap_listener_callback(self, msg):
y_new = jnp.array([coord for pos in msg.positions for coord in [pos.x, pos.y, pos.z]])

# Center the data around settled positions
y_centered = y_new - self.settled_positions
y_centered = y_new - self.settled_position

# Currently we only use the tip positions
y_centered_tip = y_centered[-3:]
# Subselect bottom two segments
y_centered_midtip = y_centered[3:]

# Update the current observations, including *single* delay embedding
self.y = jnp.concatenate([y_centered_tip, self.y[:3]])
self.y = jnp.concatenate([y_centered_midtip, self.y[:6]])

t0 = self.clock.now().nanoseconds / 1e9
x0 = self.model.encode(self.y)
Expand Down Expand Up @@ -188,7 +186,8 @@ def teleop_listener_callback(self, msg):
def service_callback(self, async_response):
try:
response = async_response.result()
self.publish_control_inputs(response.uopt)
# TODO: enable control execution
# self.publish_control_inputs(response.uopt)
except Exception as e:
self.get_logger().error(f'Service call failed: {e}.')

Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import jax.numpy as jnp
import copy
from functools import partial
from utils.ssm import DelaySSM, generate_ssm_predictions
from utils.residual import ResidualBr, PolyBr
from utils.misc import trajectories_delay_embedding, trajectories_derivatives, RK4_step, update_parameter, compute_rmse, sample_truncated_normal
from .ssm import DelaySSM, generate_ssm_predictions
from .residual import ResidualBr, PolyBr
from .misc import trajectories_delay_embedding, trajectories_derivatives, RK4_step, update_parameter, compute_rmse, sample_truncated_normal


class ReducedOrderModel:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import jax
import jax.numpy as jnp
from utils.nn import MLP
from utils.misc import polynomial_features, fit_linear_regression
from .nn import MLP
from .misc import polynomial_features, fit_linear_regression


class ResidualNN:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from scipy.linalg import orth
import numpy as np
import sympy as sp
from utils.misc import trajectories_delay_embedding
from .misc import trajectories_delay_embedding


class DelaySSM:
Expand Down

0 comments on commit 585e801

Please sign in to comment.