Skip to content

Commit

Permalink
minor xtrace chnages for uniform API
Browse files Browse the repository at this point in the history
  • Loading branch information
peekxc committed Jan 19, 2024
1 parent c79bd37 commit 917e85f
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions src/primate/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,11 +265,19 @@ def hutchpp(
"""Hutch++ estimator.
"""
_operator_checks(A)

## Catch degenerate cases
if (np.prod(A.shape) == 0) or (np.sum(A.shape) == 0):
## Catch degenerate cases
return 0

## If fun is specified, transparently convert A to matrix function
if isinstance(fun, str):
assert fun in _builtin_matrix_functions, "If given as a string, matrix_function be one of the builtin functions."
A = matrix_function(A, fun=fun)
elif isinstance(fun, Callable):
A = matrix_function(A, fun=fun)
elif fun is not None:
raise ValueError(f"Invalid matrix function type '{type(fun)}'")

## Setup constants
verbose, info = kwargs.get('verbose', False), kwargs.get('info', False)
N: int = A.shape[0]
Expand Down Expand Up @@ -366,8 +374,8 @@ def xtrace(
fun: Union[str, Callable] = None,
nv: Union[str, int] = "auto",
pdf: str = "sphere",
atol: float = 0.1,
rtol: float = 1e-6,
atol: float = 0.0,
rtol: float = 0.0,
cond_tol: float = 1e8,
verbose: int = 0,
info: bool = False,
Expand All @@ -378,7 +386,7 @@ def xtrace(
nv = int(nv) if isinstance(nv, Integral) else int(np.ceil(np.sqrt(A.shape[0])))
n = A.shape[0]

## Transparently convert A to matrix function
## If fun is specified, transparently convert A to matrix function
if isinstance(fun, str):
assert fun in _builtin_matrix_functions, "If given as a string, matrix_function be one of the builtin functions."
A = matrix_function(A, fun=fun)
Expand Down Expand Up @@ -413,6 +421,9 @@ def xtrace(
if verbose > 0:
print(f"It: {it}, est: {t:.8f}, Y_size: {Y.shape}, error: {err:.8f}")

if err <= atol:
break

if info:
info = {"estimate": t, "samples": t_samples, "error": err }
return t, info
Expand Down

0 comments on commit 917e85f

Please sign in to comment.