From 4016b0c63d3778d792c0457a1fa3fdffcbd6c038 Mon Sep 17 00:00:00 2001 From: peekxc Date: Wed, 17 Jan 2024 18:20:08 -0500 Subject: [PATCH 1/5] Initial implementation of Hutch++ --- src/primate/_trace.cpp | 17 ++ src/primate/operator.py | 18 +- src/primate/quadrature.py | 98 +++++++++++ src/primate/trace.py | 246 ++++++++++++++------------- tests/test_trace.py | 349 ++++++++++++++++++++++---------------- 5 files changed, 465 insertions(+), 263 deletions(-) create mode 100644 src/primate/quadrature.py diff --git a/src/primate/_trace.cpp b/src/primate/_trace.cpp index c4842d8..7f27f55 100644 --- a/src/primate/_trace.cpp +++ b/src/primate/_trace.cpp @@ -14,10 +14,12 @@ using namespace pybind11::literals; template< typename F > using py_array = py::array_t< F, py::array::f_style | py::array::forcecast >; + // Template function for generating module definitions for a given Operator / precision template< bool multithreaded, std::floating_point F, class Matrix, LinearOperator Wrapper > void _trace_wrapper(py::module& m){ using ArrayF = Eigen::Array< F, Dynamic, 1 >; + using VectorF = Eigen::Matrix< F, Dynamic, 1 >; m.def("hutch", []( const Matrix& A, @@ -60,6 +62,21 @@ void _trace_wrapper(py::module& m){ "matvec_time_us"_a = op.matvec_time ); }); + + // Computes the trace of Q.T @ (A @ Q) including the inner terms q_i^T A q_i + m.def("quad_sum", [](const Matrix& A, DenseMatrix< F > Q) -> py_array< F > { + const auto op = Wrapper(A); + F quad_sum = 0.0; + const size_t N = static_cast< size_t >(Q.cols()); + auto estimates = static_cast< ArrayF >(ArrayF::Zero(N)); + auto y = static_cast< VectorF >(VectorF::Zero(Q.rows())); + for (size_t j = 0; j < N; ++j){ + op.matvec(Q.col(j).data(), y.data()); + estimates[j] = Q.col(j).adjoint().dot(y); + quad_sum += estimates[j]; + } + return py::cast(estimates); + }); } // const Matrix& A, std::function< F(F) > fun, int lanczos_degree, F lanczos_rtol, int _orth, int _ncv diff --git a/src/primate/operator.py b/src/primate/operator.py index 02e0a29..7eba55d 100644 --- a/src/primate/operator.py +++ b/src/primate/operator.py @@ -114,4 +114,20 @@ def _matvec(self, x: np.ndarray) -> np.ndarray: self._z[:len(x)] = x x_fft = np.fft.fft(self._z) y = np.real(np.fft.ifft(self._dfft * x_fft)) - return y[:len(x)] \ No newline at end of file + return y[:len(x)] + + +## For use with e.g. Hutch++ +class OrthComplement(LinearOperator): + def __init__(self, A: Union[LinearOperator, np.ndarray], Q: np.ndarray): + self.A = A + self.Q = Q + self.shape = A.shape + self.dtype = self.Q.dtype + + def _deflate(self, w: np.ndarray) -> np.ndarray: + return w - self.Q @ (self.Q.T @ w) + + def _matvec(self, x: np.ndarray) -> np.ndarray: + y = self._deflate(x) + return self._deflate(self.A @ y) \ No newline at end of file diff --git a/src/primate/quadrature.py b/src/primate/quadrature.py new file mode 100644 index 0000000..aff1100 --- /dev/null +++ b/src/primate/quadrature.py @@ -0,0 +1,98 @@ +from typing import Union, Callable, Any +from numbers import Integral +import numpy as np +from scipy.sparse.linalg import LinearOperator +from scipy.linalg import solve_triangular +from scipy.stats import t +from scipy.special import erfinv +from numbers import Real + +## Package imports +from .random import _engine_prefixes, _engines, isotropic +from .special import _builtin_matrix_functions +from .operator import matrix_function +import _lanczos +import _trace +import _orthogonalize + + +def sl_gauss( + A: Union[LinearOperator, np.ndarray], + n: int = 150, + deg: int = 20, + pdf: str = "rademacher", + rng: str = "pcg", + seed: int = -1, + orth: int = 0, + num_threads: int = 0, +) -> np.ndarray: + """Stochastic Gaussian quadrature approximation. + + Computes a set of sample nodes and weights for the degree-k orthogonal polynomial approximating the + cumulative spectral measure of `A`. This function can be used to approximate the spectral density of `A`, + or to approximate the spectral sum of any function applied to the spectrum of `A`. + + Parameters + ---------- + A : ndarray, sparray, or LinearOperator + real symmetric operator. + n : int, default=150 + Number of random vectors to sample for the quadrature estimate. + deg : int, default=20 + Degree of the quadrature approximation. + rng : { 'splitmix64', 'xoshiro256**', 'pcg64', 'lcg64', 'mt64' }, default="pcg64" + Random number generator to use (PCG64 by default). + seed : int, default=-1 + Seed to initialize the `rng` entropy source. Set `seed` > -1 for reproducibility. + pdf : { 'rademacher', 'normal' }, default="rademacher" + Choice of zero-centered distribution to sample random vectors from. + orth: int, default=0 + Number of additional Lanczos vectors to orthogonalize against when building the Krylov basis. + num_threads: int, default=0 + Number of threads to use to parallelize the computation. Setting `num_threads` < 1 to let OpenMP decide. + + Returns + ------- + trace_estimate : float + Estimate of the trace of the matrix function $f(A)$. + info : dict, optional + If 'info = True', additional information about the computation. + + """ + attr_checks = [hasattr(A, "__matmul__"), hasattr(A, "matmul"), hasattr(A, "dot"), hasattr(A, "matvec")] + assert any(attr_checks), "Invalid operator; must have an overloaded 'matvec' or 'matmul' method" + assert hasattr(A, "shape") and len(A.shape) >= 2, "Operator must be at least two dimensional." + assert A.shape[0] == A.shape[1], "This function only works with square, symmetric matrices!" + + ## Choose the random number engine + assert rng in _engine_prefixes or rng in _engines, f"Invalid pseudo random number engine supplied '{str(rng)}'" + engine_id = _engine_prefixes.index(rng) if rng in _engine_prefixes else _engines.index(rng) + + ## Choose the distribution to sample random vectors from + assert pdf in ["rademacher", "normal"], f"Invalid distribution '{pdf}'; Must be one of 'rademacher' or 'normal'." + distr_id = ["rademacher", "normal"].index(pdf) + + ## Get the dtype; infer it if it's not available + f_dtype = (A @ np.zeros(A.shape[1])).dtype if not hasattr(A, "dtype") else A.dtype + assert ( + f_dtype.type == np.float32 or f_dtype.type == np.float64 + ), "Only 32- or 64-bit floating point numbers are supported." + + ## Extract the machine precision for the given floating point type + lanczos_rtol = np.finfo(f_dtype).eps # if lanczos_rtol is None else f_dtype.type(lanczos_rtol) + + ## Argument checking + m = A.shape[1] # Dimension of the space + nv = int(n) # Number of random vectors to generate + seed = int(seed) # Seed should be an integer + deg = max(deg, 2) # Must be at least 2 + orth = m - 1 if orth < 0 else min(m - 1, orth) # Number of additional vectors should be an integer + ncv = max(int(deg + orth), m) # Number of Lanczos vectors to keep in memory + num_threads = int(num_threads) # should be integer; if <= 0 will trigger max threads on C++ side + + ## Collect the arguments processed so far + sl_quad_args = (nv, distr_id, engine_id, seed, deg, lanczos_rtol, orth, ncv, num_threads) + + ## Make the actual call + quad_nw = _lanczos.stochastic_quadrature(A, *sl_quad_args) + return quad_nw diff --git a/src/primate/trace.py b/src/primate/trace.py index a5850d6..bcabade 100644 --- a/src/primate/trace.py +++ b/src/primate/trace.py @@ -1,4 +1,4 @@ -from typing import Union, Callable +from typing import Union, Callable, Any from numbers import Integral import numpy as np from scipy.sparse.linalg import LinearOperator @@ -13,12 +13,37 @@ from .operator import matrix_function import _lanczos import _trace +import _orthogonalize _default_tvals = t.ppf((0.95 + 1.0) / 2.0, df=np.arange(30)+1) # from collections import namedtuple # HutchParams = namedtuple('HutchParams', ['a', 'b']) +def _operator_checks(A: Any) -> None: + attr_checks = [hasattr(A, "__matmul__"), hasattr(A, "matmul"), hasattr(A, "dot"), hasattr(A, "matvec")] + assert any(attr_checks), "Invalid operator; must have an overloaded 'matvec' or 'matmul' method" + assert hasattr(A, "shape") and len(A.shape) >= 2, "Operator must be at least two dimensional." + assert A.shape[0] == A.shape[1], "This function only works with square, symmetric matrices!" + assert hasattr(A, "shape"), "Operator 'A' must have a valid 'shape' attribute!" + +def _estimator_msg(info) -> str: + msg = f"{info['estimator']} estimator" + msg += f" (fun={info.get('function', None)}" + if info.get('lanczos_kwargs', None) is not None: + msg += f", deg={info['lanczos_kwargs'].get('deg', 20)}" + if info.get('quad', None) is not None: + msg += f", quad={info['quad']}" + msg += ")\n" + msg += f"Est: {info['estimate']:.3f}" + if 'margin_of_error' in info: + moe, conf, cv = (info[k] for k in ['margin_of_error', 'confidence', 'coeff_var']) + msg += f" +/- {moe:.3f} ({conf*100:.0f}% CI | {(cv*100):.0f}% CV)" + msg += f", (#S:{ info['n_samples'] } | #MV:{ info['n_matvecs']}) [{info['pdf'][0].upper()}]" + if info.get('seed', -1) != -1: + msg += f" (seed: {info['seed']})" + return msg + def hutch( A: Union[LinearOperator, np.ndarray], fun: Union[str, Callable] = None, @@ -51,7 +76,7 @@ def hutch( :::{.callout-note} Convergence behavior is controlled by the `stop` parameter: "confidence" uses the central limit theorem to generate confidence intervals on the fly, which may be used in conjunction with `atol` and `rtol` to upper-bound the error of the approximation. - Alternatively, when `stop` = "change", the estimator is considered converged when the error between the last two iterates is less than + Alternatively, when `stop` = 'change', the estimator is considered converged when the error between the last two iterates is less than `atol` (or `rtol`, respectively), similar to the behavior of scipy.integrate.quadrature. ::: @@ -101,10 +126,11 @@ def hutch( Notes ----- - To compute the weights of the quadrature, `quad` can be set to either 'golub_welsch' or 'fttr'. The former (GW) uses implicit symmetric QR steps with Wilkinson shifts, - while the latter (FTTR) uses the explicit expression for orthogonal polynomials. While both require $O(\\mathrm{deg}^2)$ time to execute, - the former requires $O(\\mathrm{deg}^2)$ space but is highly accurate, while the latter uses only $O(1)$ space at the cost of stability. - If `deg` is large, `fttr` is preferred. + To compute the weights of the quadrature, `quad` can be set to either 'golub_welsch' or 'fttr'. The former uses implicit symmetric QR + steps with Wilkinson shifts, while the latter (FTTR) uses the explicit expression for orthogonal polynomials. While both require + $O(\\mathrm{deg}^2)$ time to execute, the former requires $O(\\mathrm{deg}^2)$ space but is highly accurate, while the latter uses + only $O(1)$ space at the cost of stability. If `deg` is large, `fttr` is preferred is performance, though pilot testing should be + done to ensure that instability does not cause bias in the approximation. See Also -------- @@ -115,13 +141,10 @@ def hutch( 1. Ubaru, S., Chen, J., & Saad, Y. (2017). Fast estimation of tr(f(A)) via stochastic Lanczos quadrature. SIAM Journal on Matrix Analysis and Applications, 38(4), 1075-1099. 2. Hutchinson, Michael F. "A stochastic estimator of the trace of the influence matrix for Laplacian smoothing splines." Communications in Statistics-Simulation and Computation 18.3 (1989): 1059-1076. """ - attr_checks = [hasattr(A, "__matmul__"), hasattr(A, "matmul"), hasattr(A, "dot"), hasattr(A, "matvec")] - assert any(attr_checks), "Invalid operator; must have an overloaded 'matvec' or 'matmul' method" - assert hasattr(A, "shape") and len(A.shape) >= 2, "Operator must be at least two dimensional." - assert A.shape[0] == A.shape[1], "This function only works with square, symmetric matrices!" - + ## Quick + basic input validation checks + _operator_checks(A) + ## Catch degenerate cases - assert hasattr(A, "shape"), "Operator 'A' must have a valid 'shape' attribute!" if (np.prod(A.shape) == 0) or (np.sum(A.shape) == 0): return 0 @@ -190,33 +213,19 @@ def hutch( ## Make the actual call info_dict = _trace.hutch(A, *hutch_args, **kwargs) - ## Print the status if requested - if verbose: - msg = f"Girard-Hutchinson estimator (fun={kwargs['function']}, deg={deg}, quad={quad})\n" - valid_samples = info_dict['samples'] != 0 - n_valid = sum(valid_samples) - std_error = np.std(info_dict['samples'][valid_samples], ddof=1) / np.sqrt(n_valid) - z = np.sqrt(2.0) * erfinv(confidence) - cv = np.abs(std_error / info_dict['estimate']) - msg += f"Est: {info_dict['estimate']:.3f} +/- {z * std_error:.2f} ({confidence*100:.0f}% CI), CV: {(cv*100):.0f}%, " - msg += f"Evals: { n_valid } [{pdf[0].upper()}]" - if seed != -1: - msg += f" (seed: {seed})" - print(msg) - - ## If only the point-estimate is required, return it - if not info and not plot: + ## Return as early as possible if no additional info requested for speed + if not verbose and not info and not plot: return info_dict["estimate"] - - ## Otherwise build the info - if plot: - from bokeh.plotting import show - from .plotting import figure_trace - p = figure_trace(info_dict["samples"]) - show(p) - info_dict['figure'] = figure_trace(info_dict["samples"]) - - ## Build the info dictionary + + ## Post-process info dict + info_dict['estimator'] = "Girard-Hutchinson" + info_dict['valid'] = info_dict['samples'] != 0 + info_dict['n_samples'] = np.sum(info_dict['valid']) + info_dict['n_matvecs'] = info_dict['n_samples'] * deg + info_dict['std_error'] = np.std(info_dict['samples'][info_dict['valid']], ddof=1) / np.sqrt(info_dict['n_samples']) + info_dict['coeff_var'] = np.abs(info_dict['std_error'] / info_dict['estimate']) + info_dict['margin_of_error'] = (t_values[info_dict['n_samples']] if info_dict['n_samples'] < 30 else z) * info_dict['std_error'] + info_dict['confidence'] = confidence info_dict["stop"] = stop info_dict["pdf"] = pdf info_dict["rng"] = _engines[engine_id] @@ -228,94 +237,95 @@ def hutch( info_dict["atol"] = atol info_dict["num_threads"] = "auto" if num_threads == 0 else num_threads info_dict["maxiter"] = nv - info_dict["confidence"] = confidence - return info_dict["estimate"], info_dict + ## Print the status if requested + if verbose: + print(_estimator_msg(info_dict)) -# TODO: implement hutch++ -# def hutchpp(): - -def sl_gauss( - A: Union[LinearOperator, np.ndarray], - n: int = 150, - deg: int = 20, - pdf: str = "rademacher", - rng: str = "pcg", - seed: int = -1, - orth: int = 0, - num_threads: int = 0, -) -> np.ndarray: - """Stochastic Gaussian quadrature approximation. - - Computes a set of sample nodes and weights for the degree-k orthogonal polynomial approximating the - cumulative spectral measure of `A`. This function can be used to approximate the spectral density of `A`, - or to approximate the spectral sum of any function applied to the spectrum of `A`. - - Parameters - ---------- - A : ndarray, sparray, or LinearOperator - real symmetric operator. - n : int, default=150 - Number of random vectors to sample for the quadrature estimate. - deg : int, default=20 - Degree of the quadrature approximation. - rng : { 'splitmix64', 'xoshiro256**', 'pcg64', 'lcg64', 'mt64' }, default="pcg64" - Random number generator to use (PCG64 by default). - seed : int, default=-1 - Seed to initialize the `rng` entropy source. Set `seed` > -1 for reproducibility. - pdf : { 'rademacher', 'normal' }, default="rademacher" - Choice of zero-centered distribution to sample random vectors from. - orth: int, default=0 - Number of additional Lanczos vectors to orthogonalize against when building the Krylov basis. - num_threads: int, default=0 - Number of threads to use to parallelize the computation. Setting `num_threads` < 1 to let OpenMP decide. + ## Plot samples if requested + if plot: + from bokeh.plotting import show + from .plotting import figure_trace + p = figure_trace(info_dict["samples"]) + show(p) + info_dict['figure'] = figure_trace(info_dict["samples"]) - Returns - ------- - trace_estimate : float - Estimate of the trace of the matrix function $f(A)$. - info : dict, optional - If 'info = True', additional information about the computation. + ## Final return + return (info_dict["estimate"], info_dict) if info else info_dict["estimate"] - """ - attr_checks = [hasattr(A, "__matmul__"), hasattr(A, "matmul"), hasattr(A, "dot"), hasattr(A, "matvec")] - assert any(attr_checks), "Invalid operator; must have an overloaded 'matvec' or 'matmul' method" - assert hasattr(A, "shape") and len(A.shape) >= 2, "Operator must be at least two dimensional." - assert A.shape[0] == A.shape[1], "This function only works with square, symmetric matrices!" - ## Choose the random number engine - assert rng in _engine_prefixes or rng in _engines, f"Invalid pseudo random number engine supplied '{str(rng)}'" - engine_id = _engine_prefixes.index(rng) if rng in _engine_prefixes else _engines.index(rng) +def hutchpp( + A: Union[LinearOperator, np.ndarray], + fun: Union[str, Callable] = None, + b: int = "auto", + maxiter: int = 200, + mode: str = 'reduced', + **kwargs +) -> Union[float, dict]: + _operator_checks(A) - ## Choose the distribution to sample random vectors from - assert pdf in ["rademacher", "normal"], f"Invalid distribution '{pdf}'; Must be one of 'rademacher' or 'normal'." - distr_id = ["rademacher", "normal"].index(pdf) + ## Catch degenerate cases + if (np.prod(A.shape) == 0) or (np.sum(A.shape) == 0): + return 0 - ## Get the dtype; infer it if it's not available + ## Setup constants + verbose, info = kwargs.get('verbose', False), kwargs.get('info', False) + N: int = A.shape[0] + b = (N // 3) if b == "auto" else b + m = (N // 3) if maxiter == "auto" else maxiter + assert m % 3 == 0, "Number of sample vectors 'm' must be divisible by 3." f_dtype = (A @ np.zeros(A.shape[1])).dtype if not hasattr(A, "dtype") else A.dtype - assert ( - f_dtype.type == np.float32 or f_dtype.type == np.float64 - ), "Only 32- or 64-bit floating point numbers are supported." - - ## Extract the machine precision for the given floating point type - lanczos_rtol = np.finfo(f_dtype).eps # if lanczos_rtol is None else f_dtype.type(lanczos_rtol) - - ## Argument checking - m = A.shape[1] # Dimension of the space - nv = int(n) # Number of random vectors to generate - seed = int(seed) # Seed should be an integer - deg = max(deg, 2) # Must be at least 2 - orth = m - 1 if orth < 0 else min(m - 1, orth) # Number of additional vectors should be an integer - ncv = max(int(deg + orth), m) # Number of Lanczos vectors to keep in memory - num_threads = int(num_threads) # should be integer; if <= 0 will trigger max threads on C++ side - - ## Collect the arguments processed so far - sl_quad_args = (nv, distr_id, engine_id, seed, deg, lanczos_rtol, orth, ncv, num_threads) - - ## Make the actual call - quad_nw = _lanczos.stochastic_quadrature(A, *sl_quad_args) - return quad_nw - + info_dict = {} + + ## Raw isotropic random vectors + # W = np.random.choice([-1.0, +1.0], size=(N, M)).astype(f_dtype) + # W1, W2 = W[:,:(m // 3)], W[:,(m // 3):] + WB = np.random.choice([-1.0, +1.0], size=(N, b)).astype(f_dtype) + + ## Sketch Y - use numpy for now, but consider parallelizing MGS later + Q = np.linalg.qr(A @ WB)[0] + # Y = np.array(A @ W2, order='F') + # assert Y.flags['F_CONTIGUOUS'] and Y.flags['OWNDATA'] and Y.flags['WRITEABLE'] + # _orthogonalize.mgs(Y, 0) + # Q = Y # Q is mostly orthonormal + + ## Estimate trace of the low-rank approx. / sketch + bulk_tr = 0.0 + if mode == 'full': + bulk_tr = (Q.T @ (A @ Q)).trace() + else: + bulk_tr = np.sum(_trace.quad_sum(A, Q)) + + ## Estimate trace of the residual + residual_tr = 0.0 + if mode == 'full': + WM = np.random.choice([-1.0, +1.0], size=(N, m)).astype(f_dtype) + G = WM - Q @ (Q.T @ WM) + residual_tr += (1 / m) * (G.T @ (A @ G)).trace() + else: + from primate.operator import OrthComplement + PC = OrthComplement(A, Q) # evaluates (I - Q^T Q)A(I - Q^T Q) + kwargs['maxiter'] = kwargs.get('maxiter', m) + if not info and not verbose: + residual_tr = hutch(PC, **kwargs) + return bulk_tr + residual_tr + else: + kwargs['info'] = True + kwargs['verbose'] = False + residual_tr, ID = hutch(PC, **kwargs) + info_dict.update(ID) + + ## Modify the info dict + info_dict['estimate'] = bulk_tr + residual_tr + info_dict['estimator'] = 'Hutch++' + info_dict['n_matvecs'] = 2*b + info_dict['n_samples'] + info_dict['n_samples'] = b + info_dict['n_samples'] + info_dict['pdf'] = 'rademacher' + + ## Print as needed + if verbose: + print(_estimator_msg(info_dict)) + return info_dict['estimate'] if not info else (info_dict['estimate'], info_dict) def __xtrace(W: np.ndarray, Z: np.ndarray, Q: np.ndarray, R: np.ndarray, method: str): """Helper for xtrace function""" diff --git a/tests/test_trace.py b/tests/test_trace.py index ccceb26..e3d1cb8 100644 --- a/tests/test_trace.py +++ b/tests/test_trace.py @@ -15,167 +15,228 @@ ## NOTE: trace estimation only works with isotropic vectors def test_girard_fixed(): - from sanity import girard_hutch - np.random.seed(1234) - n = 30 - ew = 0.2 + 1.5*np.linspace(0, 5, n) - Q,R = np.linalg.qr(np.random.uniform(size=(n,n))) - A = Q @ np.diag(ew) @ Q.T - A = (A + A.T) / 2 - ew_true = np.linalg.eigvalsh(A) - tr_est = girard_hutch(A, lambda x: x, nv = n, estimates=False) - threshold = 0.05*(np.max(ew)*n - np.min(ew)*n) - assert np.allclose(ew_true, ew) - assert np.isclose(A.trace() - tr_est, 0.0, atol=threshold) + from sanity import girard_hutch + np.random.seed(1234) + n = 30 + ew = 0.2 + 1.5*np.linspace(0, 5, n) + Q,R = np.linalg.qr(np.random.uniform(size=(n,n))) + A = Q @ np.diag(ew) @ Q.T + A = (A + A.T) / 2 + ew_true = np.linalg.eigvalsh(A) + tr_est = girard_hutch(A, lambda x: x, nv = n, estimates=False) + threshold = 0.05*(np.max(ew)*n - np.min(ew)*n) + assert np.allclose(ew_true, ew) + assert np.isclose(A.trace() - tr_est, 0.0, atol=threshold) def test_trace_import(): - import primate.trace - assert '_trace' in dir(primate.trace) - from primate.trace import hutch, _trace - assert 'hutch' in dir(_trace) - assert isinstance(hutch, Callable) + import primate.trace + assert '_trace' in dir(primate.trace) + from primate.trace import hutch, _trace + assert 'hutch' in dir(_trace) + assert isinstance(hutch, Callable) def test_trace_basic(): - from primate.trace import hutch - np.random.seed(1234) - n = 10 - A = symmetric(n) - tr_test1 = hutch(A, maxiter=100, seed=5, num_threads=1) - tr_test2 = hutch(A, maxiter=100, seed=5, num_threads=1) - tr_true = A.trace() - assert tr_test1 == tr_test2, "Builds not reproducible!" - assert np.isclose(tr_test1, tr_true, atol=tr_true*0.05) + from primate.trace import hutch + np.random.seed(1234) + n = 10 + A = symmetric(n) + tr_test1 = hutch(A, maxiter=100, seed=5, num_threads=1) + tr_test2 = hutch(A, maxiter=100, seed=5, num_threads=1) + tr_true = A.trace() + assert tr_test1 == tr_test2, "Builds not reproducible!" + assert np.isclose(tr_test1, tr_true, atol=tr_true*0.05) def test_trace_pdfs(): - from primate.trace import hutch - np.random.seed(1234) - n = 50 - A = symmetric(n) - tr_test1 = hutch(A, maxiter=200, seed=5, num_threads=1, pdf="rademacher") - tr_test2 = hutch(A, maxiter=200, seed=5, num_threads=1, pdf="normal") - tr_true = A.trace() - assert np.isclose(tr_test1, tr_test2, atol=tr_true*0.05) + from primate.trace import hutch + np.random.seed(1234) + n = 50 + A = symmetric(n) + tr_test1 = hutch(A, maxiter=200, seed=5, num_threads=1, pdf="rademacher") + tr_test2 = hutch(A, maxiter=200, seed=5, num_threads=1, pdf="normal") + tr_true = A.trace() + assert np.isclose(tr_test1, tr_test2, atol=tr_true*0.05) def test_trace_inputs(): - from primate.trace import hutch - n = 10 - A = symmetric(n) - tr_1 = hutch(A, maxiter=100) - tr_2 = hutch(csc_array(A), maxiter=100) - tr_3 = hutch(aslinearoperator(A), maxiter=100) - assert all([isinstance(t, Number) for t in [tr_1, tr_2, tr_3]]) + from primate.trace import hutch + n = 10 + A = symmetric(n) + tr_1 = hutch(A, maxiter=100) + tr_2 = hutch(csc_array(A), maxiter=100) + tr_3 = hutch(aslinearoperator(A), maxiter=100) + assert all([isinstance(t, Number) for t in [tr_1, tr_2, tr_3]]) def test_hutch_info(): - from primate.trace import hutch - np.random.seed(1234) - n = 25 - A = csc_array(symmetric(n), dtype=np.float32) - tr_est, info = hutch(A, maxiter = 200, info=True) - assert isinstance(info, dict) and isinstance(tr_est, Number) - assert len(info['samples']) == 200 - assert np.all(~np.isclose(info['samples'], 0.0)) - assert np.isclose(tr_est, A.trace(), atol=1.0) + from primate.trace import hutch + np.random.seed(1234) + n = 25 + A = csc_array(symmetric(n), dtype=np.float32) + tr_est, info = hutch(A, maxiter = 200, info=True) + assert isinstance(info, dict) and isinstance(tr_est, Number) + assert len(info['samples']) == 200 + assert np.all(~np.isclose(info['samples'], 0.0)) + assert np.isclose(tr_est, A.trace(), atol=1.0) def test_hutch_multithread(): - from primate.trace import hutch - np.random.seed(1234) - n = 25 - A = csc_array(symmetric(n), dtype=np.float32) - tr_est, info = hutch(A, maxiter = 200, atol=0.0, info = True, num_threads=6) - assert len(info['samples'] == 200) - assert np.all(~np.isclose(info['samples'], 0.0)) - assert np.isclose(tr_est, A.trace(), atol=1.0) + from primate.trace import hutch + np.random.seed(1234) + n = 25 + A = csc_array(symmetric(n), dtype=np.float32) + tr_est, info = hutch(A, maxiter = 200, atol=0.0, info = True, num_threads=6) + assert len(info['samples'] == 200) + assert np.all(~np.isclose(info['samples'], 0.0)) + assert np.isclose(tr_est, A.trace(), atol=1.0) def test_hutch_clt_atol(): - from primate.trace import hutch - np.random.seed(1234) - n = 30 - A = csc_array(symmetric(n), dtype=np.float32) - - from primate.stats import sample_mean_cinterval - tr_est, info = hutch(A, maxiter = 100, num_threads=1, seed=5, info=True) - tr_samples = info['samples'] - ci = np.array([sample_mean_cinterval(tr_samples[:i], sdist='normal') if i > 1 else [-np.inf, np.inf] for i in range(len(tr_samples))]) - - ## Detect when, for the fixed set of samples, the trace estimator should converge by CLT - atol_threshold = (A.trace() * 0.05) - clt_converged = np.ravel(0.5*np.diff(ci, axis=1)) <= atol_threshold - assert np.any(clt_converged), "Did not converge!" - converged_ind = np.flatnonzero(clt_converged)[0] - - ## Re-run with same seed and ensure the index matches - tr_est, info = hutch(A, maxiter = 100, num_threads=1, atol=atol_threshold, seed=5, info=True) - tr_samples = info['samples'] - converged_online = np.take(np.flatnonzero(tr_samples == 0.0), 0) - assert converged_online == converged_ind, "hutch not converging at correct index!" + from primate.trace import hutch + np.random.seed(1234) + n = 30 + A = csc_array(symmetric(n), dtype=np.float32) + + from primate.stats import sample_mean_cinterval + tr_est, info = hutch(A, maxiter = 100, num_threads=1, seed=5, info=True) + tr_samples = info['samples'] + ci = np.array([sample_mean_cinterval(tr_samples[:i], sdist='normal') if i > 1 else [-np.inf, np.inf] for i in range(len(tr_samples))]) + + ## Detect when, for the fixed set of samples, the trace estimator should converge by CLT + atol_threshold = (A.trace() * 0.05) + clt_converged = np.ravel(0.5*np.diff(ci, axis=1)) <= atol_threshold + assert np.any(clt_converged), "Did not converge!" + converged_ind = np.flatnonzero(clt_converged)[0] + + ## Re-run with same seed and ensure the index matches + tr_est, info = hutch(A, maxiter = 100, num_threads=1, atol=atol_threshold, seed=5, info=True) + tr_samples = info['samples'] + converged_online = np.take(np.flatnonzero(tr_samples == 0.0), 0) + assert converged_online == converged_ind, "hutch not converging at correct index!" def test_hutch_change(): - from primate.trace import hutch - np.random.seed(1234) - n = 30 - A = csc_array(symmetric(n), dtype=np.float32) - tr_est, info = hutch(A, maxiter = 100, num_threads=1, seed=5, info=True) - tr_samples = info['samples'] - estimator = np.cumsum(tr_samples) / np.arange(1, 101) - conv_ind_true = np.flatnonzero(np.abs(np.diff(estimator)) <= 0.001)[0] + 1 - - ## Test the convergence checking for the atol change method - tr_est, info = hutch(A, maxiter = 100, num_threads=1, seed=5, info=True, atol=0.001, stop="change") - conv_ind_test = np.take(np.flatnonzero(info['samples'] == 0), 0) - assert abs(conv_ind_true - conv_ind_test) <= 1 + from primate.trace import hutch + np.random.seed(1234) + n = 30 + A = csc_array(symmetric(n), dtype=np.float32) + tr_est, info = hutch(A, maxiter = 100, num_threads=1, seed=5, info=True) + tr_samples = info['samples'] + estimator = np.cumsum(tr_samples) / np.arange(1, 101) + conv_ind_true = np.flatnonzero(np.abs(np.diff(estimator)) <= 0.001)[0] + 1 + + ## Test the convergence checking for the atol change method + tr_est, info = hutch(A, maxiter = 100, num_threads=1, seed=5, info=True, atol=0.001, stop="change") + conv_ind_test = np.take(np.flatnonzero(info['samples'] == 0), 0) + assert abs(conv_ind_true - conv_ind_test) <= 1 def test_trace_mf(): - from primate.trace import hutch - n = 10 - A = symmetric(n) - tr_est = hutch(A, fun="identity", maxiter=100, num_threads=1, seed = 5) - tr_true = A.trace() - assert np.isclose(tr_est, tr_true, atol=tr_true*0.05) - tr_est = hutch(A, fun=lambda x: x, maxiter=100, num_threads=1, seed = 5) - assert np.isclose(tr_est, tr_true, atol=tr_true*0.05) + from primate.trace import hutch + n = 10 + A = symmetric(n) + tr_est = hutch(A, fun="identity", maxiter=100, num_threads=1, seed = 5) + tr_true = A.trace() + assert np.isclose(tr_est, tr_true, atol=tr_true*0.05) + tr_est = hutch(A, fun=lambda x: x, maxiter=100, num_threads=1, seed = 5) + assert np.isclose(tr_est, tr_true, atol=tr_true*0.05) def test_trace_fftr(): - from primate.trace import hutch - n = 50 - A = symmetric(n) - - ## Test the fttr against the golub_welsch - tr_est1, info1 = hutch(A, fun="identity", maxiter=100, seed=5, num_threads=1, info=True, quad="golub_welsch") - tr_est2, info2 = hutch(A, fun="identity", maxiter=100, seed=5, num_threads=1, info=True, quad="fttr") - assert np.isclose(tr_est1, tr_est2, atol=tr_est1*0.01) - - ## Test accuracy - assert np.isclose(A.trace(), tr_est1, atol=tr_est1*0.025) - - # from primate.diagonalize import lanczos, _lanczos - # v0 = np.array([-1, 1, 1,-1, 1,-1, 1,-1,-1, 1]) / np.sqrt(10) - # a, b = lanczos(A, v0=v0, deg=10) - # a, b = a, np.append([0], b) - # _lanczos.quadrature(a, b, 10, 0) - - - # from primate.operator import matrix_function - # M = matrix_function(A, fun="identity") - # M.method = "fttr" - # M.quad(np.random.choice([-1.0, +1.0], size=n)) - - ## TODO - # tr_est, info = hutch(A, fun=lambda x: x, maxiter=100, seed=5, num_threads=1, info=True) - # assert np.isclose(tr_est, tr_true, atol=tr_true*0.05) - # for s in range(15000): - # est, info = hutch(A, fun="identity", deg=2, maxiter=200, num_threads=1, seed=591, info=True) - # assert not np.isnan(est) - # for s in range(15000): - # est, info = hutch(A, fun="identity", deg=20, maxiter=200, num_threads=8, seed=-1, info=True) - # assert not np.isnan(est) - - # from primate.operator import matrix_function - # M = matrix_function(A, fun="identity", deg=20) - # for s in range(15000): - # v0 = np.random.choice([-1, 1], size=M.shape[0]) - # assert not np.isnan(M.quad(v0)) - - - # from primate.diagonalize import lanczos - # lanczos(A, v0=v0, rtol=M.rtol, deg=M.deg, orth=M.orth) - # if np.any(np.isnan(info['samples'])) \ No newline at end of file + from primate.trace import hutch + n = 50 + A = symmetric(n) + + ## Test the fttr against the golub_welsch + tr_est1, info1 = hutch(A, fun="identity", maxiter=100, seed=5, num_threads=1, info=True, quad="golub_welsch") + tr_est2, info2 = hutch(A, fun="identity", maxiter=100, seed=5, num_threads=1, info=True, quad="fttr") + assert np.isclose(tr_est1, tr_est2, atol=tr_est1*0.01) + + ## Test accuracy + assert np.isclose(A.trace(), tr_est1, atol=tr_est1*0.025) + + # from primate.diagonalize import lanczos, _lanczos + # v0 = np.array([-1, 1, 1,-1, 1,-1, 1,-1,-1, 1]) / np.sqrt(10) + # a, b = lanczos(A, v0=v0, deg=10) + # a, b = a, np.append([0], b) + # _lanczos.quadrature(a, b, 10, 0) + + + # from primate.operator import matrix_function + # M = matrix_function(A, fun="identity") + # M.method = "fttr" + # M.quad(np.random.choice([-1.0, +1.0], size=n)) + + ## TODO + # tr_est, info = hutch(A, fun=lambda x: x, maxiter=100, seed=5, num_threads=1, info=True) + # assert np.isclose(tr_est, tr_true, atol=tr_true*0.05) + # for s in range(15000): + # est, info = hutch(A, fun="identity", deg=2, maxiter=200, num_threads=1, seed=591, info=True) + # assert not np.isnan(est) + # for s in range(15000): + # est, info = hutch(A, fun="identity", deg=20, maxiter=200, num_threads=8, seed=-1, info=True) + # assert not np.isnan(est) + + # from primate.operator import matrix_function + # M = matrix_function(A, fun="identity", deg=20) + # for s in range(15000): + # v0 = np.random.choice([-1, 1], size=M.shape[0]) + # assert not np.isnan(M.quad(v0)) + + + # from primate.diagonalize import lanczos + # lanczos(A, v0=v0, rtol=M.rtol, deg=M.deg, orth=M.orth) + # if np.any(np.isnan(info['samples'])) + +def test_quad_sum(): + from primate.trace import _trace + np.random.seed(1234) + n = 100 + A = symmetric(n) + Q = np.linalg.qr(A)[0] + test_quads = _trace.quad_sum(A, Q) + true_quads = (Q.T @ (A @ Q)).diagonal() + assert np.allclose(test_quads, true_quads) + assert np.isclose(np.sum(test_quads), np.sum(true_quads)) + test_quads = _trace.quad_sum(A, Q[:,:10]) + true_quads = (Q[:,:10].T @ (A @ Q[:,:10])).diagonal() + assert np.allclose(test_quads, true_quads) + assert np.isclose(np.sum(test_quads), np.sum(true_quads)) + +# def test_hutch_pp(): +# from primate.trace import hutchpp +# np.random.seed(1234) +# n = 100 +# A = symmetric(n) +# hutchpp(A, mode="reduced", verbose=True) + + # ## Raw isotropic random vectors + # m = 20 * 3 + # N: int = A.shape[0] + # M: int = 2 * m // 3 + # f_dtype = (A @ np.zeros(A.shape[1])).dtype if not hasattr(A, "dtype") else A.dtype + # W = np.random.choice([-1.0, +1.0], size=(N, M)).astype(f_dtype) + # W1, W2 = W[:,:(m // 3)], W[:,(m // 3):] + # Q = np.linalg.qr(A @ W2)[0] + + # ## Start with estimate using the largest eigen-spaces + # bulk_tr = (Q.T @ (A @ Q)).trace() + + # ## Estimate residual via Girard + # # from scipy.sparse.linalg import aslinearoperator, LinearOperator + # # B = np.eye(A.shape[0]) - Q @ Q.T + # # deflate_proj = lambda w: B @ (A @ (B @ w)) + # # L = LinearOperator(matvec = deflate_proj, shape=A.shape) + # from primate.operator import OrthComplement + # L = OrthComplement(A, Q) + + # from primate.trace import hutch + # true_residual = A.trace() - bulk_tr + # np.abs((bulk_tr + hutch(L, atol = 0.10, maxiter=1000, verbose=True)) - A.trace()) + # hutch(L) + # hutch(L, atol = 0.001, maxiter=200, deg=40, verbose=True, plot=True) + # 0.9534050697745574 - 0.003 <= true_residual + # true_residual <= 0.9534050697745574 + 0.003 + + # residual_tr = 0.0 + # if True: + # G = W1 - Q @ Q.T @ W1 + # residual_tr = (1 / (m // 3)) * (G.T @ (A @ G)).trace() + + # print(f"A trace: {A.trace():8f}") + # for i in range(1, A.shape[0]): + # np.random.seed(1234) + # est = hutch_pp(A, m = 3 * i) + # print(f"{i}: {est} (error: {np.abs(est - A.trace())})") \ No newline at end of file From eab759183b1bef341419d79c3413d6c9a62619a1 Mon Sep 17 00:00:00 2001 From: peekxc Date: Fri, 19 Jan 2024 14:22:27 -0500 Subject: [PATCH 2/5] Re-organizing + adding CentOS build to cirrus for discovery --- .cirrus.yml | 4 +++- src/primate/functional.py | 1 + src/primate/meson.build | 1 + src/primate/trace.py | 25 +++++++++++++------------ tests/test_slq.py | 2 +- tests/test_trace.py | 20 ++++++++++++++------ 6 files changed, 33 insertions(+), 20 deletions(-) diff --git a/.cirrus.yml b/.cirrus.yml index 23306f8..f2e5c3f 100644 --- a/.cirrus.yml +++ b/.cirrus.yml @@ -1,10 +1,12 @@ env: CIRRUS_CLONE_SUBMODULES: true +# image: quay.io/pypa/manylinux_2_24_x86_64 # debian linux_task: container: - image: quay.io/pypa/manylinux_2_28_x86_64 + # image: quay.io/pypa/manylinux_2_28_x86_64 + image: quay.io/pypa/manylinux2014_x86_64 only_if: changesInclude('.cirrus.yml', '**.{h,cpp,py}') env: CC: clang diff --git a/src/primate/functional.py b/src/primate/functional.py index b8fd081..3164288 100644 --- a/src/primate/functional.py +++ b/src/primate/functional.py @@ -10,6 +10,7 @@ from .diagonalize import lanczos ## Since Python has no support for inline creation of sized generators +## Based on "Estimating the Largest Eigenvalue by the Power and Lanczos Algorithms with a Random Start" class RelativeErrorBound(): def __init__(self, n: int): self.base_num = 2.575 * np.log(n) diff --git a/src/primate/meson.build b/src/primate/meson.build index 144eb99..6fb3283 100644 --- a/src/primate/meson.build +++ b/src/primate/meson.build @@ -27,6 +27,7 @@ python_sources = [ 'stats.py', 'special.py', 'functional.py', + 'quadrature.py', '__init__.py' ] diff --git a/src/primate/trace.py b/src/primate/trace.py index bcabade..d85dd77 100644 --- a/src/primate/trace.py +++ b/src/primate/trace.py @@ -10,7 +10,7 @@ ## Package imports from .random import _engine_prefixes, _engines, isotropic from .special import _builtin_matrix_functions -from .operator import matrix_function +from .operator import matrix_function, OrthComplement import _lanczos import _trace import _orthogonalize @@ -262,6 +262,8 @@ def hutchpp( mode: str = 'reduced', **kwargs ) -> Union[float, dict]: + """Hutch++ estimator. + """ _operator_checks(A) ## Catch degenerate cases @@ -271,18 +273,14 @@ def hutchpp( ## Setup constants verbose, info = kwargs.get('verbose', False), kwargs.get('info', False) N: int = A.shape[0] - b = (N // 3) if b == "auto" else b - m = (N // 3) if maxiter == "auto" else maxiter - assert m % 3 == 0, "Number of sample vectors 'm' must be divisible by 3." + b = (N // 3) if b == "auto" else b # main samples + m = (N // 3) if maxiter == "auto" else maxiter # residual samples + # assert m % 3 == 0, "Number of sample vectors 'm' must be divisible by 3." f_dtype = (A @ np.zeros(A.shape[1])).dtype if not hasattr(A, "dtype") else A.dtype info_dict = {} - ## Raw isotropic random vectors - # W = np.random.choice([-1.0, +1.0], size=(N, M)).astype(f_dtype) - # W1, W2 = W[:,:(m // 3)], W[:,(m // 3):] + ## Sketch Y / Q - use numpy for now, but consider parallelizing MGS later WB = np.random.choice([-1.0, +1.0], size=(N, b)).astype(f_dtype) - - ## Sketch Y - use numpy for now, but consider parallelizing MGS later Q = np.linalg.qr(A @ WB)[0] # Y = np.array(A @ W2, order='F') # assert Y.flags['F_CONTIGUOUS'] and Y.flags['OWNDATA'] and Y.flags['WRITEABLE'] @@ -299,11 +297,14 @@ def hutchpp( ## Estimate trace of the residual residual_tr = 0.0 if mode == 'full': + ## Full mode == form the full (m x m) matrix and take the diagonal + ## Note memory efficient, but is vectorized, so suitable for relatively small m WM = np.random.choice([-1.0, +1.0], size=(N, m)).astype(f_dtype) G = WM - Q @ (Q.T @ WM) residual_tr += (1 / m) * (G.T @ (A @ G)).trace() else: - from primate.operator import OrthComplement + ## reduced mode == Switch to plain Hutch estimator on orthogonal complement projector + ## Low memory footprint but might be 5x slower or worse on small inputs PC = OrthComplement(A, Q) # evaluates (I - Q^T Q)A(I - Q^T Q) kwargs['maxiter'] = kwargs.get('maxiter', m) if not info and not verbose: @@ -318,8 +319,8 @@ def hutchpp( ## Modify the info dict info_dict['estimate'] = bulk_tr + residual_tr info_dict['estimator'] = 'Hutch++' - info_dict['n_matvecs'] = 2*b + info_dict['n_samples'] - info_dict['n_samples'] = b + info_dict['n_samples'] + info_dict['n_matvecs'] = 2*b + info_dict.get('n_samples', m) + info_dict['n_samples'] = b + info_dict.get('n_samples', m) info_dict['pdf'] = 'rademacher' ## Print as needed diff --git a/tests/test_slq.py b/tests/test_slq.py index a684309..9e58fc7 100644 --- a/tests/test_slq.py +++ b/tests/test_slq.py @@ -15,7 +15,7 @@ def test_stochastic_quadrature(): from primate.diagonalize import _lanczos assert hasattr(_lanczos, "stochastic_quadrature"), "Module compile failed" - from primate.trace import sl_gauss + from primate.quadrature import sl_gauss np.random.seed(1234) n, nv, lanczos_deg = 30, 250, 20 A = csc_array(symmetric(30), dtype=np.float32) diff --git a/tests/test_trace.py b/tests/test_trace.py index e3d1cb8..8c67243 100644 --- a/tests/test_trace.py +++ b/tests/test_trace.py @@ -195,12 +195,20 @@ def test_quad_sum(): assert np.allclose(test_quads, true_quads) assert np.isclose(np.sum(test_quads), np.sum(true_quads)) -# def test_hutch_pp(): -# from primate.trace import hutchpp -# np.random.seed(1234) -# n = 100 -# A = symmetric(n) -# hutchpp(A, mode="reduced", verbose=True) +def test_hutch_pp(): + from primate.trace import hutchpp + np.random.seed(1234) + n = 100 + A = symmetric(n) + test_trace = hutchpp(A, mode="reduced", seed=1) + true_trace = A.trace() + assert np.isclose(test_trace, true_trace, atol=1.0) + test_trace = hutchpp(A, mode="full", seed=1) + assert np.isclose(test_trace, true_trace, atol=1.0) + + # import timeit + # timeit.timeit(lambda: hutchpp(A, mode="full"), number=1000) + # timeit.timeit(lambda: hutchpp(A, mode="reduced"), number=1000) # ## Raw isotropic random vectors # m = 20 * 3 From 3872ad78a00f836ea9a34bd85a905f680f960bc9 Mon Sep 17 00:00:00 2001 From: peekxc Date: Fri, 19 Jan 2024 14:40:48 -0500 Subject: [PATCH 3/5] cirrus changes --- .cirrus.yml | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/.cirrus.yml b/.cirrus.yml index f2e5c3f..73e6074 100644 --- a/.cirrus.yml +++ b/.cirrus.yml @@ -1,17 +1,20 @@ env: CIRRUS_CLONE_SUBMODULES: true -# image: quay.io/pypa/manylinux_2_24_x86_64 # debian - +# quay.io/pypa/manylinux2014_x86_64 # CentOS 7 (use GCC 10) +# quay.io/pypa/manylinux_2_24_x86_64 # Debian (unknown) +# quay.io/pypa/manylinux_2_28_x86_64 # AlmaLinux (use clang) linux_task: container: # image: quay.io/pypa/manylinux_2_28_x86_64 image: quay.io/pypa/manylinux2014_x86_64 only_if: changesInclude('.cirrus.yml', '**.{h,cpp,py}') env: - CC: clang - CXX: clang++ - PATH: ${PATH}:/opt/python/cp310-cp310/bin + # CC: clang + # CXX: clang++ + CC: gcc + CXX: g++ + PATH: /opt/python/cp310-cp310/bin:${PATH} pip_cache: folder: ~/.cache/pip fingerprint_script: echo $PYTHON_VERSION From c79bd3759b2b84f976eb963a9a2f1e6d3dbae466 Mon Sep 17 00:00:00 2001 From: peekxc Date: Fri, 19 Jan 2024 14:49:00 -0500 Subject: [PATCH 4/5] CI changes --- pyproject.toml | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 801bfe6..9ba1d92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,11 +44,15 @@ test-requires = ["pytest", "pytest-cov", "pytest-benchmark", "bokeh"] # coverage test-command = "python -m pytest {package}/tests/ --cov={package} --benchmark-skip" build-verbosity = 1 skip = "cp36-* pp* cp37-* *_ppc64le *_i686 *_s390x *-musllinux*" # todo: revisit musllinux -manylinux-x86_64-image = "manylinux_2_28" # prefer the newer one +# manylinux-x86_64-image = "manylinux_2_28" # prefer the newer one +manylinux-x86_64-image = "manylinux2014" + [tool.cibuildwheel.linux] before-build = "bash {project}/tools/cibw_linux.sh {project}" -environment = { CC="clang", CXX="clang++" } +environment = { CC="gcc", CXX="g++" } # for manylinux2014 +# environment = { CC="clang", CXX="clang++" } # for manylinux_2_28 + # before-build = ["ulimit -n 4096", "yum install -y clang", "yum install -y openblas"] [tool.cibuildwheel.macos] From 917e85f994c506fc42b15eca4bec0d7ab81318df Mon Sep 17 00:00:00 2001 From: peekxc Date: Fri, 19 Jan 2024 16:24:45 -0500 Subject: [PATCH 5/5] minor xtrace chnages for uniform API --- src/primate/trace.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/src/primate/trace.py b/src/primate/trace.py index d85dd77..ee81a3d 100644 --- a/src/primate/trace.py +++ b/src/primate/trace.py @@ -265,11 +265,19 @@ def hutchpp( """Hutch++ estimator. """ _operator_checks(A) - - ## Catch degenerate cases if (np.prod(A.shape) == 0) or (np.sum(A.shape) == 0): + ## Catch degenerate cases return 0 + ## If fun is specified, transparently convert A to matrix function + if isinstance(fun, str): + assert fun in _builtin_matrix_functions, "If given as a string, matrix_function be one of the builtin functions." + A = matrix_function(A, fun=fun) + elif isinstance(fun, Callable): + A = matrix_function(A, fun=fun) + elif fun is not None: + raise ValueError(f"Invalid matrix function type '{type(fun)}'") + ## Setup constants verbose, info = kwargs.get('verbose', False), kwargs.get('info', False) N: int = A.shape[0] @@ -366,8 +374,8 @@ def xtrace( fun: Union[str, Callable] = None, nv: Union[str, int] = "auto", pdf: str = "sphere", - atol: float = 0.1, - rtol: float = 1e-6, + atol: float = 0.0, + rtol: float = 0.0, cond_tol: float = 1e8, verbose: int = 0, info: bool = False, @@ -378,7 +386,7 @@ def xtrace( nv = int(nv) if isinstance(nv, Integral) else int(np.ceil(np.sqrt(A.shape[0]))) n = A.shape[0] - ## Transparently convert A to matrix function + ## If fun is specified, transparently convert A to matrix function if isinstance(fun, str): assert fun in _builtin_matrix_functions, "If given as a string, matrix_function be one of the builtin functions." A = matrix_function(A, fun=fun) @@ -413,6 +421,9 @@ def xtrace( if verbose > 0: print(f"It: {it}, est: {t:.8f}, Y_size: {Y.shape}, error: {err:.8f}") + if err <= atol: + break + if info: info = {"estimate": t, "samples": t_samples, "error": err } return t, info