Add a verbose option to BFGS. (#95)
* adding a verbose option to BFGS

* adding default values for the verbose option to custom BFGS solvers

---------

Co-authored-by: Johanna Haffner <[email protected]>
johannahaffner and Johanna Haffner authored Dec 21, 2024
Parent: 1d8e79b · Commit: eb8d50e
Showing 2 changed files with 28 additions and 0 deletions.
optimistix/_solver/bfgs.py (23 additions, 0 deletions)
@@ -19,6 +19,7 @@
     max_norm,
     tree_dot,
     tree_full_like,
+    verbose_print,
 )
 from .._search import (
     AbstractDescent,
@@ -162,6 +163,7 @@ class AbstractBFGS(
     use_inverse: AbstractVar[bool]
     descent: AbstractVar[AbstractDescent[Y, _Hessian, Any]]
     search: AbstractVar[AbstractSearch[Y, _Hessian, FunctionInfo.Eval, Any]]
+    verbose: AbstractVar[frozenset[str]]

     def init(
         self,
@@ -252,6 +254,20 @@ def rejected(descent_state):
             accept, accepted, rejected, state.descent_state
         )

+        if len(self.verbose) > 0:
+            verbose_loss = "loss" in self.verbose
+            verbose_step_size = "step_size" in self.verbose
+            verbose_y = "y" in self.verbose
+            loss_eval = f_eval
+            loss = state.f_info.f
+            verbose_print(
+                (verbose_loss, "Loss on this step", loss_eval),
+                (verbose_loss, "Loss on the last accepted step", loss),
+                (verbose_step_size, "Step size", step_size),
+                (verbose_y, "y", state.y_eval),
+                (verbose_y, "y on the last accepted step", y),
+            )

         y_descent, descent_result = self.descent.step(step_size, descent_state)
         y_eval = (y**ω + y_descent**ω).ω
         result = RESULTS.where(
@@ -310,13 +326,15 @@ class BFGS(AbstractBFGS[Y, Aux, _Hessian], strict=True):
     use_inverse: bool
     descent: NewtonDescent
     search: BacktrackingArmijo
+    verbose: frozenset[str]

     def __init__(
         self,
         rtol: float,
         atol: float,
         norm: Callable[[PyTree], Scalar] = max_norm,
         use_inverse: bool = True,
+        verbose: frozenset[str] = frozenset(),
     ):
         self.rtol = rtol
         self.atol = atol
@@ -325,6 +343,7 @@ def __init__(
         self.descent = NewtonDescent(linear_solver=lx.Cholesky())
         # TODO(raderj): switch out `BacktrackingArmijo` with a better line search.
         self.search = BacktrackingArmijo()
+        self.verbose = verbose


 BFGS.__init__.__doc__ = """**Arguments:**
@@ -345,4 +364,8 @@ def __init__(
     default is (b), denoted via `use_inverse=True`. Note that this is incompatible with
     line search methods like [`optimistix.ClassicalTrustRegion`][], which use the
     Hessian approximation `B` as part of their own computations.
+- `verbose`: Whether to print out extra information about how the solve is
+    proceeding. Should be a frozenset of strings, specifying what information to print.
+    Valid entries are `step_size`, `loss`, `y`. For example
+    `verbose=frozenset({"step_size", "loss"})`.
 """
tests/helpers.py (5 additions, 0 deletions)
@@ -120,6 +120,7 @@ class BFGSDampedNewton(optx.AbstractBFGS):
     use_inverse: bool = False
     search: optx.AbstractSearch = optx.ClassicalTrustRegion()
     descent: optx.AbstractDescent = optx.DampedNewtonDescent()
+    verbose: frozenset[str] = frozenset()


 class BFGSIndirectDampedNewton(optx.AbstractBFGS):
@@ -131,6 +132,7 @@ class BFGSIndirectDampedNewton(optx.AbstractBFGS):
     use_inverse: bool = False
     search: optx.AbstractSearch = optx.ClassicalTrustRegion()
     descent: optx.AbstractDescent = optx.IndirectDampedNewtonDescent()
+    verbose: frozenset[str] = frozenset()


 class BFGSDogleg(optx.AbstractBFGS):
@@ -142,6 +144,7 @@ class BFGSDogleg(optx.AbstractBFGS):
     use_inverse: bool = False
     search: optx.AbstractSearch = optx.ClassicalTrustRegion()
     descent: optx.AbstractDescent = optx.DoglegDescent(linear_solver=lx.SVD())
+    verbose: frozenset[str] = frozenset()


 class BFGSBacktracking(optx.AbstractBFGS):
@@ -153,6 +156,7 @@ class BFGSBacktracking(optx.AbstractBFGS):
     use_inverse: bool = False
     search: optx.AbstractSearch = optx.BacktrackingArmijo()
     descent: optx.AbstractDescent = optx.NewtonDescent()
+    verbose: frozenset[str] = frozenset()


 class BFGSTrustRegion(optx.AbstractBFGS):
@@ -164,6 +168,7 @@ class BFGSTrustRegion(optx.AbstractBFGS):
     use_inverse: bool = False
     search: optx.AbstractSearch = optx.LinearTrustRegion()
     descent: optx.AbstractDescent = optx.NewtonDescent()
+    verbose: frozenset[str] = frozenset()


 atol = rtol = 1e-8
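
Since `verbose` is declared as an `AbstractVar` on `AbstractBFGS`, every custom subclass must now define it, which is exactly what the defaults added above provide. A minimal sketch of a custom solver under the new interface (field values are illustrative, mirroring the helpers above):

    from collections.abc import Callable

    import optimistix as optx


    class MyBFGS(optx.AbstractBFGS):
        rtol: float = 1e-8
        atol: float = 1e-8
        norm: Callable = optx.max_norm
        use_inverse: bool = False
        search: optx.AbstractSearch = optx.BacktrackingArmijo()
        descent: optx.AbstractDescent = optx.NewtonDescent()
        verbose: frozenset[str] = frozenset()  # newly required field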
