Skip to content

Commit

Permalink
Add possibility to test PINN with different inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
ergocub committed Sep 24, 2024
1 parent 9afee9a commit 01bc68d
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ class BipedalLocomotion::JointTorqueControlDevice
std::vector<double> m_motorPositionsRadians;
std::vector<std::string> m_axisNames;
LowPassFilterParameters m_lowPassFilterParameters;
int m_modelType{0};

yarp::os::Port m_rpcPort; /**< Remote Procedure Call port. */

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@ class PINNFrictionEstimator
* @param[in] modelPath a string representing the path to the ONNX model
* @param[in] intraOpNumThreads a std::size_t representing the number of threads to be used for intra-op parallelism
* @param[in] interOpNumThreads a std::size_t representing the number of threads to be used for inter-op parallelism
* @param[in] modelType a std::size_t representing the type of the model
* @return true if the initialization is successful, false otherwise
*/
bool initialize(const std::string& modelPath,
const std::size_t intraOpNumThreads = 1,
const std::size_t interOpNumThreads = 1);
const std::size_t interOpNumThreads = 1,
const std::size_t modelType = 0);

/**
* Reset the estimator
Expand All @@ -47,12 +49,20 @@ class PINNFrictionEstimator

/**
* Estimate the joint friction starting from raw data
* @param[in] inputJointPositon a double representing the joint position (rad)
* @param[in] inputMotorPosition a double representing the motor position motor side (rad)
* @param[in] inputMotorVelocity a double representing the motor velocity (rad/sec)
* @param[in] inputDeltaPosition a double representing difference between the joint position and the motor position motor side (rad)
* @param[in] inputJointVelocity a double representing the joint velocity (rad/sec)
* @param[out] output a double representing the joint friction torque
* @return true if the estimation is successful, false otherwise
*/
bool estimate(double inputDeltaPosition, double inputJointVelocity, double& output);
bool estimate(double inputJointPositon,
double inputMotorPosition,
double inputMotorVelocity,
double inputDeltaPosition,
double inputJointVelocity,
double& output);


private:
Expand Down
14 changes: 12 additions & 2 deletions devices/JointTorqueControlDevice/src/JointTorqueControlDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,10 @@ double JointTorqueControlDevice::computeFrictionTorque(int joint)
m_motorPositionError[joint] = m_tempJointPosMotorSideRad - m_motorPositionCorrected[joint];

// Test network with inputs position error motor side, joint velocity
if (!frictionEstimators[joint]->estimate(m_motorPositionError[joint],
if (!frictionEstimators[joint]->estimate(m_tempJointPosRad,
m_motorPositionsRadians[joint],
measuredMotorVelocities[joint] * M_PI / 180.0,
m_motorPositionError[joint],
m_jointVelRadSec,
frictionTorque))
{
Expand Down Expand Up @@ -728,6 +731,12 @@ bool JointTorqueControlDevice::loadFrictionParams(
return false;
}

if (!frictionGroup->getParameter("model_type", m_modelType))
{
log()->error("{} Parameter `model_type` not found", logPrefix);
// return false;
}

for (int i = 0; i < models.size(); i++)
{
pinnParameters[i].modelPath = models[i];
Expand Down Expand Up @@ -872,7 +881,8 @@ bool JointTorqueControlDevice::open(yarp::os::Searchable& config)

if (!frictionEstimators[i]->initialize(pinnParameters[i].modelPath,
threadNumber,
threadNumber))
threadNumber,
m_modelType))
{
log()->error("{} Failed to initialize friction estimator", logPrefix);
return false;
Expand Down
115 changes: 104 additions & 11 deletions devices/JointTorqueControlDevice/src/PINNFrictionEstimator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,14 @@ struct PINNFrictionEstimator::Impl
std::unique_ptr<Ort::Session> session;
Ort::MemoryInfo memoryInfo;

std::deque<float> jointPositionBuffer;
std::deque<float> jointVelocityBuffer;
std::deque<float> errorPositionBuffer;
std::deque<float> motorPositionBuffer;
std::deque<float> motorVelocityBuffer;

size_t historyLength;
size_t modelType;

struct DataStructured
{
Expand All @@ -58,8 +62,9 @@ PINNFrictionEstimator::PINNFrictionEstimator()
PINNFrictionEstimator::~PINNFrictionEstimator() = default;

bool PINNFrictionEstimator::initialize(const std::string& networkModelPath,
const std::size_t intraOpNumThreads,
const std::size_t interOpNumThreads)
const std::size_t intraOpNumThreads,
const std::size_t interOpNumThreads,
const std::size_t modelType)
{
std::basic_string<ORTCHAR_T> networkModelPathAsOrtString(networkModelPath.begin(),
networkModelPath.end());
Expand All @@ -80,7 +85,6 @@ bool PINNFrictionEstimator::initialize(const std::string& networkModelPath,
networkModelPathAsOrtString.c_str(),
sessionOptions);


if (m_pimpl->session == nullptr)
{
BipedalLocomotion::log()->error("Unable to load the model from the file: {}", networkModelPath);
Expand All @@ -91,15 +95,41 @@ bool PINNFrictionEstimator::initialize(const std::string& networkModelPath,
std::vector<int64_t> inputShape = m_pimpl->session->GetInputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();

const std::size_t inputCount = inputShape[1];
m_pimpl->historyLength = inputCount / 2;

// Check the model type
m_pimpl->modelType = modelType;

int numberOfInputs = 0;
if (modelType == 1 || modelType == 2)
{
numberOfInputs = 2;
}
else if (modelType == 3 || modelType == 4)
{
numberOfInputs = 3;
}
else if (modelType == 5)
{
numberOfInputs = 4;
}
else
{
BipedalLocomotion::log()->error("Wrong model type");
return false;
}

m_pimpl->historyLength = inputCount / numberOfInputs;

// format the input
m_pimpl->structuredInput.rawData.resize(m_pimpl->historyLength * 2);
m_pimpl->structuredInput.rawData.resize(inputCount);
m_pimpl->structuredInput.shape[0] = 1; // batch
m_pimpl->structuredInput.shape[1] = m_pimpl->historyLength * 2;
m_pimpl->structuredInput.shape[1] = inputCount;

m_pimpl->jointPositionBuffer.resize(m_pimpl->historyLength);
m_pimpl->jointVelocityBuffer.resize(m_pimpl->historyLength);
m_pimpl->errorPositionBuffer.resize(m_pimpl->historyLength);
m_pimpl->motorPositionBuffer.resize(m_pimpl->historyLength);
m_pimpl->motorVelocityBuffer.resize(m_pimpl->historyLength);

// create tensor required by onnx
m_pimpl->structuredInput.tensor
Expand Down Expand Up @@ -130,22 +160,34 @@ bool PINNFrictionEstimator::initialize(const std::string& networkModelPath,

void PINNFrictionEstimator::resetEstimator()
{
m_pimpl->jointPositionBuffer.clear();
m_pimpl->motorPositionBuffer.clear();
m_pimpl->motorVelocityBuffer.clear();
m_pimpl->jointVelocityBuffer.clear();
m_pimpl->errorPositionBuffer.clear();
}

bool PINNFrictionEstimator::estimate(double inputDeltaPosition,
double inputJointVelocity,
double& output)
bool PINNFrictionEstimator::estimate(double inputJointPositon,
double inputMotorPosition,
double inputMotorVelocity,
double inputDeltaPosition,
double inputJointVelocity,
double& output)
{
if (m_pimpl->errorPositionBuffer.size() == m_pimpl->historyLength)
{
// The buffer is full, remove the oldest element
m_pimpl->jointPositionBuffer.pop_front();
m_pimpl->motorPositionBuffer.pop_front();
m_pimpl->motorVelocityBuffer.pop_front();
m_pimpl->errorPositionBuffer.pop_front();
m_pimpl->jointVelocityBuffer.pop_front();
}

// Push element into the queue
m_pimpl->jointPositionBuffer.push_back(inputJointPositon);
m_pimpl->motorPositionBuffer.push_back(inputMotorPosition);
m_pimpl->motorVelocityBuffer.push_back(inputMotorVelocity);
m_pimpl->errorPositionBuffer.push_back(inputDeltaPosition);
m_pimpl->jointVelocityBuffer.push_back(inputJointVelocity);

Expand All @@ -160,12 +202,63 @@ bool PINNFrictionEstimator::estimate(double inputDeltaPosition,
// Copy the joint positions and then the motor positions in the
// structured input without emptying the buffer
// Use iterators to copy the data to the vector
std::copy(m_pimpl->errorPositionBuffer.cbegin(),
if (m_pimpl->modelType == 1)
{
std::copy(m_pimpl->errorPositionBuffer.cbegin(),
m_pimpl->errorPositionBuffer.cend(),
m_pimpl->structuredInput.rawData.begin());
std::copy(m_pimpl->jointVelocityBuffer.cbegin(),
std::copy(m_pimpl->jointVelocityBuffer.cbegin(),
m_pimpl->jointVelocityBuffer.cend(),
m_pimpl->structuredInput.rawData.begin() + m_pimpl->historyLength);
}
else if (m_pimpl->modelType == 2)
{
std::copy(m_pimpl->motorVelocityBuffer.cbegin(),
m_pimpl->motorVelocityBuffer.cend(),
m_pimpl->structuredInput.rawData.begin());
std::copy(m_pimpl->jointVelocityBuffer.cbegin(),
m_pimpl->jointVelocityBuffer.cend(),
m_pimpl->structuredInput.rawData.begin() + m_pimpl->historyLength);
}
else if (m_pimpl->modelType == 3)
{
std::copy(m_pimpl->errorPositionBuffer.cbegin(),
m_pimpl->errorPositionBuffer.cend(),
m_pimpl->structuredInput.rawData.begin());
std::copy(m_pimpl->motorVelocityBuffer.cbegin(),
m_pimpl->motorVelocityBuffer.cend(),
m_pimpl->structuredInput.rawData.begin() + m_pimpl->historyLength);
std::copy(m_pimpl->jointVelocityBuffer.cbegin(),
m_pimpl->jointVelocityBuffer.cend(),
m_pimpl->structuredInput.rawData.begin() + 2 * m_pimpl->historyLength);
}
else if (m_pimpl->modelType == 4)
{
std::copy(m_pimpl->errorPositionBuffer.cbegin(),
m_pimpl->errorPositionBuffer.cend(),
m_pimpl->structuredInput.rawData.begin());
std::copy(m_pimpl->jointPositionBuffer.cbegin(),
m_pimpl->jointPositionBuffer.cend(),
m_pimpl->structuredInput.rawData.begin() + m_pimpl->historyLength);
std::copy(m_pimpl->jointVelocityBuffer.cbegin(),
m_pimpl->jointVelocityBuffer.cend(),
m_pimpl->structuredInput.rawData.begin() + 2 * m_pimpl->historyLength);
}
else if (m_pimpl->modelType == 5)
{
std::copy(m_pimpl->motorPositionBuffer.cbegin(),
m_pimpl->motorPositionBuffer.cend(),
m_pimpl->structuredInput.rawData.begin());
std::copy(m_pimpl->motorVelocityBuffer.cbegin(),
m_pimpl->motorVelocityBuffer.cend(),
m_pimpl->structuredInput.rawData.begin() + m_pimpl->historyLength);
std::copy(m_pimpl->jointPositionBuffer.cbegin(),
m_pimpl->jointPositionBuffer.cend(),
m_pimpl->structuredInput.rawData.begin() + 2 * m_pimpl->historyLength);
std::copy(m_pimpl->jointVelocityBuffer.cbegin(),
m_pimpl->jointVelocityBuffer.cend(),
m_pimpl->structuredInput.rawData.begin() + 3 * m_pimpl->historyLength);
}

// perform the inference
const char* inputNames[] = {"input"};
Expand Down

0 comments on commit 01bc68d

Please sign in to comment.