Skip to content

Commit

Permalink
run black
Browse files Browse the repository at this point in the history
  • Loading branch information
PFLeget committed Sep 23, 2024
1 parent dc5719f commit 4719d5a
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 17 deletions.
11 changes: 9 additions & 2 deletions python/lsst/meas/algorithms/gp_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@

import logging

__all__ = ["InterpolateOverDefectGaussianProcess", "GaussianProcessJax", "GaussianProcessTreegp"]
__all__ = [
"InterpolateOverDefectGaussianProcess",
"GaussianProcessJax",
"GaussianProcessTreegp",
]


def updateMaskFromArray(mask, bad_pixel, interpBit):
Expand Down Expand Up @@ -191,6 +195,7 @@ class GaussianProcessJax:
Mean value of the Gaussian Process. Default is 0.0.
"""

def __init__(self, std=1.0, correlation_length=1.0, white_noise=0.0, mean=0.0):
self.std = std
self.correlation_length = correlation_length
Expand All @@ -217,7 +222,9 @@ def fit(self, x_train, y_train):
self._alpha = jax_get_alpha(y, kernel)

def predict(self, x_predict):
kernel_rect = jax_rbf_kernel(x_predict, self._x, self.std, self.correlation_length)
kernel_rect = jax_rbf_kernel(
x_predict, self._x, self.std, self.correlation_length
)
y_pred = jax_get_gp_predict(kernel_rect, self._alpha)
return y_pred + self.mean

Expand Down
40 changes: 25 additions & 15 deletions tests/test_gp_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@
import lsst.utils.tests
import lsst.geom
import lsst.afw.image as afwImage
from lsst.meas.algorithms import InterpolateOverDefectGaussianProcess, GaussianProcessTreegp
from lsst.meas.algorithms import (
InterpolateOverDefectGaussianProcess,
GaussianProcessTreegp,
)


def rbf_kernel(x1, x2, sigma, correlation_length):
"""
Expand Down Expand Up @@ -62,7 +66,7 @@ def setUp(self):

npoints = 1000
self.std = 100
self.correlation_length = 10.
self.correlation_length = 10.0
self.white_noise = 1e-5

x1 = np.random.uniform(0, 99, npoints)
Expand All @@ -71,7 +75,7 @@ def setUp(self):

kernel = rbf_kernel(coord1, coord1, self.std, self.correlation_length)
kernel += np.eye(npoints) * self.white_noise**2

# Data augmentation. Create a gaussian random field
# on a 100 * 100 is to slow. So generate 1e3 points
# and then interpolate it with a GP to do data augmentation.
Expand All @@ -84,8 +88,12 @@ def setUp(self):
x2, x1 = np.meshgrid(x2, x1)
coord2 = np.array([x1.reshape(-1), x2.reshape(-1)]).T

tgp = GaussianProcessTreegp(std=self.std, correlation_length=self.correlation_length,
white_noise=self.white_noise, mean=0.0)
tgp = GaussianProcessTreegp(
std=self.std,
correlation_length=self.correlation_length,
white_noise=self.white_noise,
mean=0.0,
)
tgp.fit(coord1, z1)
z2 = tgp.predict(coord2)
z2 = z2.reshape(100, 121)
Expand Down Expand Up @@ -144,16 +152,18 @@ def test_interpolation(self, method: str):
Code used to solve gaussian process.
"""

gp = InterpolateOverDefectGaussianProcess(self.maskedimage,
defects=["BAD", "SAT", "CR", "EDGE"],
method=method,
fwhm=self.correlation_length,
bin_image=False,
bin_spacing=30,
threshold_dynamic_binning=1000,
threshold_subdivide=20000,
correlation_length_cut=5,
log=None,)
gp = InterpolateOverDefectGaussianProcess(
self.maskedimage,
defects=["BAD", "SAT", "CR", "EDGE"],
method=method,
fwhm=self.correlation_length,
bin_image=False,
bin_spacing=30,
threshold_dynamic_binning=1000,
threshold_subdivide=20000,
correlation_length_cut=5,
log=None,
)

gp.run()

Expand Down

0 comments on commit 4719d5a

Please sign in to comment.