Skip to content

Commit

Permalink
add possibility to set pinn model from rpc port
Browse files Browse the repository at this point in the history
  • Loading branch information
ergocub committed Oct 3, 2024
1 parent 01bc68d commit 207f8d0
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ struct PINNParameters
{
std::string modelPath; /**< PINN model path */
int threadNumber; /**< number of threads */
int modelType; /**< type of the model */

/**
* Reset the parameters
Expand All @@ -96,6 +97,7 @@ struct PINNParameters
{
modelPath = "";
threadNumber = 0;
modelType = 0;
}
};

Expand Down Expand Up @@ -185,7 +187,6 @@ class BipedalLocomotion::JointTorqueControlDevice
std::vector<double> m_motorPositionsRadians;
std::vector<std::string> m_axisNames;
LowPassFilterParameters m_lowPassFilterParameters;
int m_modelType{0};

yarp::os::Port m_rpcPort; /**< Remote Procedure Call port. */

Expand Down Expand Up @@ -295,6 +296,8 @@ class BipedalLocomotion::JointTorqueControlDevice
virtual double getMaxFrictionTorque(const std::string& jointName) override;
virtual bool setFrictionModel(const std::string& jointName, const std::string& model) override;
virtual std::string getFrictionModel(const std::string& jointName) override;
virtual bool setPINNModel(const std::string& jointName, const std::string& pinnModelName, const int modelType) override;
virtual std::string getPINNModel(const std::string& jointName) override;
};

#endif // BIPEDAL_LOCOMOTION_FRAMEWORK_JOINT_TORQUE_CONTROL_DEVICE_H
Expand Down
66 changes: 60 additions & 6 deletions devices/JointTorqueControlDevice/src/JointTorqueControlDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,60 @@ std::string JointTorqueControlDevice::getFrictionModel(const std::string& jointN
return model;
}

bool JointTorqueControlDevice::setPINNModel(const std::string& jointName,
const std::string& pinnModelName,
const int modelType)
{
auto it = std::find(m_axisNames.begin(), m_axisNames.end(), jointName);

// If jointName is found
if (it != m_axisNames.end())
{
// Calculate the index of the found element
size_t index = std::distance(m_axisNames.begin(), it);

pinnParameters[index].modelPath = pinnModelName;
pinnParameters[index].modelType = modelType;

// Lock the mutex to safely modify motorTorqueCurrentParameters
std::lock_guard<std::mutex> lock(mutexTorqueControlParam_);

if (!frictionEstimators[index]->initialize(pinnParameters[index].modelPath,
pinnParameters[index].threadNumber,
pinnParameters[index].threadNumber,
pinnParameters[index].modelType))
{
log()->error("[JointTorqueControlDevice::setPINNModel] Failed to re-initialize friction estimator with model {}", pinnModelName);
return false;
}

return true;
}

return false;
}

std::string JointTorqueControlDevice::getPINNModel(const std::string& jointName)
{
std::string pinnModelName = "none";

size_t index = 0;

do
{
if (m_axisNames[index] == jointName)
{
std::lock_guard<std::mutex> lock(mutexTorqueControlParam_);

return pinnParameters[index].modelPath;
}

index++;
} while (index < m_axisNames.size());

return pinnModelName;
}

// HIJACKING CONTROL
void JointTorqueControlDevice::startHijackingTorqueControlIfNecessary(int j)
{
Expand Down Expand Up @@ -731,7 +785,8 @@ bool JointTorqueControlDevice::loadFrictionParams(
return false;
}

if (!frictionGroup->getParameter("model_type", m_modelType))
int modelType;
if (!frictionGroup->getParameter("model_type", modelType))
{
log()->error("{} Parameter `model_type` not found", logPrefix);
// return false;
Expand All @@ -741,6 +796,7 @@ bool JointTorqueControlDevice::loadFrictionParams(
{
pinnParameters[i].modelPath = models[i];
pinnParameters[i].threadNumber = threads;
pinnParameters[i].modelType = modelType;
}
}

Expand Down Expand Up @@ -871,18 +927,16 @@ bool JointTorqueControlDevice::open(yarp::os::Searchable& config)
return false;
}

int threadNumber = 1;

for (int i = 0; i < kt.size(); i++)
{
if (motorTorqueCurrentParameters[i].frictionModel == "FRICTION_PINN")
{
frictionEstimators[i] = std::make_unique<PINNFrictionEstimator>();

if (!frictionEstimators[i]->initialize(pinnParameters[i].modelPath,
threadNumber,
threadNumber,
m_modelType))
pinnParameters[i].threadNumber,
pinnParameters[i].threadNumber,
pinnParameters[i].modelType))
{
log()->error("{} Failed to initialize friction estimator", logPrefix);
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,9 @@ service JointTorqueControlCommands
bool setFrictionModel(1:string jointName, 2:string model);

string getFrictionModel(1:string jointName);

bool setPINNModel(1:string jointName, 2:string pinnModelName, 3:i32 modelType);

string getPINNModel(1:string jointName);
}

0 comments on commit 207f8d0

Please sign in to comment.