Skip to content

Commit

Permalink
Fix for boundary as a control (#3816)
Browse files Browse the repository at this point in the history
* Fix for the case where the boundary condition is a control parameter.
  • Loading branch information
Ig-dolci authored Oct 23, 2024
1 parent 50cd88a commit b49191e
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
3 changes: 2 additions & 1 deletion firedrake/adjoint_utils/blocks/dirichlet_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
adj_value = firedrake.Function(self.collapsed_space, vec.dat)

if adj_value.ufl_shape == () or adj_value.ufl_shape[0] <= 1:
r = adj_value.dat.data_ro.sum()
R = firedrake.FunctionSpace(self.parent_space.mesh(), "R", 0)
r = firedrake.Function(R.dual(), val=adj_value.dat.data_ro.sum())
else:
output = []
subindices = _extract_subindices(self.function_space)
Expand Down
33 changes: 33 additions & 0 deletions tests/regression/test_adjoint_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,3 +1003,36 @@ def test_cofunction_assign_functional():
assert np.allclose(float(Jhat.derivative()), 1.0)
f.assign(2.0)
assert np.allclose(Jhat(f), 2.0)


@pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done
def test_bdy_control():
# Test for the case the boundary condition is a control for a
# domain with length different from 1.
mesh = IntervalMesh(10, 0, 2)
X = SpatialCoordinate(mesh)
space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)
sol = Function(space, name="sol")
# Dirichlet boundary conditions
R = FunctionSpace(mesh, "R", 0)
a = Function(R, val=1.0)
b = Function(R, val=2.0)
bc_left = DirichletBC(space, a, 1)
bc_right = DirichletBC(space, b, 2)
bc = [bc_left, bc_right]
F = dot(grad(trial), grad(test)) * dx
problem = LinearVariationalProblem(lhs(F), rhs(F), sol, bcs=bc)
solver = LinearVariationalSolver(problem)
solver.solve()
# Analytical solution of the analytical Laplace equation is:
# u(x) = a + (b - a)/2 * x
u_analytical = a + (b - a)/2 * X[0]
der_analytical0 = assemble(derivative((u_analytical**2) * dx, a))
der_analytical1 = assemble(derivative((u_analytical**2) * dx, b))
J = assemble(sol * sol * dx)
J_hat = ReducedFunctional(J, [Control(a), Control(b)])
adj_derivatives = J_hat.derivative(options={"riesz_representation": "l2"})
assert np.allclose(adj_derivatives[0].dat.data_ro, der_analytical0.dat.data_ro)
assert np.allclose(adj_derivatives[1].dat.data_ro, der_analytical1.dat.data_ro)

0 comments on commit b49191e

Please sign in to comment.