Skip to content

Commit

Permalink
Updated dy_u/dy_l->dy in bindings.cpp, bindings.cpp.in, interface.py,…
Browse files Browse the repository at this point in the history
… derivatice_test.py
  • Loading branch information
AmitSolomonPrinceton committed Apr 11, 2024
1 parent 9e12594 commit a556d35
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 29 deletions.
3 changes: 2 additions & 1 deletion src/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class PyOSQPSolver {
OSQPInt update_data_mat(py::object, py::object, py::object, py::object);
OSQPInt warm_start(py::object, py::object);
OSQPInt solve();
OSQPInt adjoint_derivative_compute(py::object, py::object, py::object);
OSQPInt adjoint_derivative_compute(py::object, py::object);
OSQPInt adjoint_derivative_get_mat(CSC&, CSC&);
OSQPInt adjoint_derivative_get_vec(py::object, py::object, py::object);

Expand Down Expand Up @@ -291,6 +291,7 @@ OSQPInt PyOSQPSolver::adjoint_derivative_compute(const py::object dx, const py::
_dy = (OSQPFloat *)_dy_array.data();
}


return osqp_adjoint_derivative_compute(this->_solver, _dx, _dy);
}

Expand Down
26 changes: 9 additions & 17 deletions src/bindings.cpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class PyOSQPSolver {
OSQPInt update_data_mat(py::object, py::object, py::object, py::object);
OSQPInt warm_start(py::object, py::object);
OSQPInt solve();
OSQPInt adjoint_derivative_compute(py::object, py::object, py::object);
OSQPInt adjoint_derivative_compute(py::object, py::object);
OSQPInt adjoint_derivative_get_mat(CSC&, CSC&);
OSQPInt adjoint_derivative_get_vec(py::object, py::object, py::object);

Expand Down Expand Up @@ -273,10 +273,9 @@ OSQPInt PyOSQPSolver::update_data_mat(py::object P_x, py::object P_i, py::object
return osqp_update_data_mat(this->_solver, _P_x, _P_i, _P_n, _A_x, _A_i, _A_n);
}

OSQPInt PyOSQPSolver::adjoint_derivative_compute(const py::object dx, const py::object dy_l, const py::object dy_u) {
OSQPInt PyOSQPSolver::adjoint_derivative_compute(const py::object dx, const py::object dy) {
OSQPFloat* _dx;
OSQPFloat* _dy_l;
OSQPFloat* _dy_u;
OSQPFloat* _dy;

if (dx.is_none()) {
_dx = NULL;
Expand All @@ -285,22 +284,15 @@ OSQPInt PyOSQPSolver::adjoint_derivative_compute(const py::object dx, const py::
_dx = (OSQPFloat *)_dx_array.data();
}

if (dy_l.is_none()) {
_dy_l = NULL;
if (dy.is_none()) {
_dy = NULL;
} else {
auto _dy_l_array = py::array_t<OSQPFloat>(dy_l);
_dy_l = (OSQPFloat *)_dy_l_array.data();
auto _dy_array = py::array_t<OSQPFloat>(dy);
_dy = (OSQPFloat *)_dy_array.data();
}

if (dy_u.is_none()) {
_dy_u = NULL;
} else {
auto _dy_u_array = py::array_t<OSQPFloat>(dy_u);
_dy_u = (OSQPFloat *)_dy_u_array.data();
}

return osqp_adjoint_derivative_compute(this->_solver, _dx, _dy_l, _dy_u);

return osqp_adjoint_derivative_compute(this->_solver, _dx, _dy);
}

OSQPInt PyOSQPSolver::adjoint_derivative_get_mat(CSC& dP, CSC& dA) {
Expand Down Expand Up @@ -489,7 +481,7 @@ PYBIND11_MODULE(@OSQP_EXT_MODULE_NAME@, m) {
.def("update_rho", &PyOSQPSolver::update_rho)
.def("get_settings", &PyOSQPSolver::get_settings, py::return_value_policy::reference)

.def("adjoint_derivative_compute", &PyOSQPSolver::adjoint_derivative_compute, "dx"_a.none(true), "dy_l"_a.none(true), "dy_u"_a.none(true))
.def("adjoint_derivative_compute", &PyOSQPSolver::adjoint_derivative_compute, "dx"_a.none(true), "dy"_a.none(true))
.def("adjoint_derivative_get_mat", &PyOSQPSolver::adjoint_derivative_get_mat, "dP"_a, "dA"_a)
.def("adjoint_derivative_get_vec", &PyOSQPSolver::adjoint_derivative_get_vec, "dq"_a, "dl"_a, "du"_a)

Expand Down
10 changes: 4 additions & 6 deletions src/osqp/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def codegen(

return folder

def adjoint_derivative_compute(self, dx=None, dy_l=None, dy_u=None):
def adjoint_derivative_compute(self, dx=None, dy=None):
"""
Compute adjoint derivative after solve.
"""
Expand All @@ -450,12 +450,10 @@ def adjoint_derivative_compute(self, dx=None, dy_l=None, dy_u=None):
if results.info.status != 'solved':
raise ValueError('Problem has not been solved to optimality. ' 'You cannot take derivatives')

if dy_u is None:
dy_u = np.zeros(self.m)
if dy_l is None:
dy_l = np.zeros(self.m)
if dy is None:
dy = np.zeros(self.m)

self._solver.adjoint_derivative_compute(dx, dy_l, dy_u)
self._solver.adjoint_derivative_compute(dx, dy)

def adjoint_derivative_get_mat(self, as_dense=True, dP_as_triu=True):
"""
Expand Down
8 changes: 3 additions & 5 deletions src/osqp/tests/derivative_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def get_prob(self, n=10, m=3, P_scale=1.0, A_scale=1.0):

return [P, q, A, l, u, true_x, true_yl, true_yu]

def get_grads(self, P, q, A, l, u, true_x, true_yl=None, true_yu=None, mode='qdldl'):
def get_grads(self, P, q, A, l, u, true_x, true_y=None, mode='qdldl'):
# Get gradients by solving with osqp
m = osqp.OSQP(algebra='builtin')
m.setup(
Expand All @@ -66,12 +66,10 @@ def get_grads(self, P, q, A, l, u, true_x, true_yl=None, true_yu=None, mode='qdl
raise ValueError('Problem not solved!')
x = results.x
y = results.y
yl = -np.minimum(y, 0)
yu = np.maximum(y, 0)
if true_yl is None and true_yu is None:
if true_y is None:
m.adjoint_derivative_compute(dx=x - true_x)
else:
m.adjoint_derivative_compute(dx=x - true_x, dy_l=yl - true_yl, dy_u=yu - true_yu)
m.adjoint_derivative_compute(dx=x - true_x, dy=y - true_y)

dP, dA = m.adjoint_derivative_get_mat()
dq, dl, du = m.adjoint_derivative_get_vec()
Expand Down

0 comments on commit a556d35

Please sign in to comment.