From 37f4e4cfd89019f172374c4f4623d850fed57fbf Mon Sep 17 00:00:00 2001 From: PFLeget Date: Mon, 20 May 2024 18:49:42 -0400 Subject: [PATCH] Save in memory Cholesky decomposition to avoid re-compute it. (#26) * saved cholesky output to avoid re-doing same computation * reset cashed cholesky value if re-run computation oh hyperparameters * fix typo * add comments from Clare --- treegp/gp_interp.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/treegp/gp_interp.py b/treegp/gp_interp.py index 62378fc..630db1e 100644 --- a/treegp/gp_interp.py +++ b/treegp/gp_interp.py @@ -106,6 +106,8 @@ def __init__( self._X0 = X0 self._y0 = y0 + self._alpha = None + def _fit(self, kernel, X, y, y_err): """Update the Kernel with data. @@ -114,6 +116,7 @@ def _fit(self, kernel, X, y, y_err): :param y: Values of the field. (n_samples) :param y_err: Error of y. (n_samples) """ + self._alpha = None if self.optimizer != "none": # Hyperparameters estimation using 2-point correlation # function information. @@ -172,11 +175,15 @@ def return_gp_predict(self, y, X1, X2, kernel, y_err, return_cov=False): :param y_err: Error of y. (n_samples) """ HT = kernel.__call__(X2, Y=X1) - K = kernel.__call__(X1) + np.eye(len(y)) * y_err**2 - factor = (cholesky(K, overwrite_a=True, lower=False), False) - alpha = cho_solve(factor, y, overwrite_b=False) - y_predict = np.dot(HT, alpha.reshape((len(alpha), 1))).T[0] + K = None + if self._alpha is None: + K = kernel.__call__(X1) + np.eye(len(y)) * y_err**2 + factor = (cholesky(K, overwrite_a=True, lower=False), False) + self._alpha = cho_solve(factor, y, overwrite_b=False) + y_predict = np.dot(HT, self._alpha.reshape((len(self._alpha), 1))).T[0] if return_cov: + if K is None: + K = kernel.__call__(X1) + np.eye(len(y)) * y_err**2 fact = cholesky( K, lower=True ) # I am computing maybe twice the same things... @@ -215,6 +222,9 @@ def initialize(self, X, y, y_err=None): self._mean = np.mean(y - self._spatial_average) else: self._mean = 0.0 + # Initialize alpha to None so that we know to recompute it if we change the + # input data. + self._alpha = None def _build_average_meanify(self, X): """Compute spatial average from meanify output for a given coordinate using KN interpolation.