
Commit

continue work on Clare's comments
PFLeget committed Sep 17, 2024
1 parent 549da19 commit e55200a
Showing 1 changed file with 53 additions and 105 deletions.
158 changes: 53 additions & 105 deletions python/lsst/meas/algorithms/gp_interpolation.py
@@ -73,14 +73,16 @@ def median_with_mad_clipping(data, mad_multiplier=2.0):


@jit
def jax_rbf_kernel(x, sigma, correlation_length, y_err):
def jax_rbf_kernel(x1, x2, sigma, correlation_length, y_err):
"""
Computes the radial basis function (RBF) kernel matrix.
Parameters:
-----------
x : `np.array`
Input data points with shape (n_samples, n_features).
x1 : `np.array`
Locations of the training data points, with shape (n_samples, n_features).
x2 : `np.array`
Locations of the training/test data points, with shape (n_samples, n_features).
sigma : `float`
The scale parameter of the kernel.
correlation_length : `float`
@@ -93,39 +95,12 @@ def jax_rbf_kernel(x, sigma, correlation_length, y_err):
kernel : `np.array`
RBF kernel matrix with shape (len(x1), len(x2)).
"""
distance_squared = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
distance_squared = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
kernel = (sigma**2) * jnp.exp(-0.5 * distance_squared / (correlation_length**2))
y_err = jnp.ones(len(x[:, 0])) * y_err

Check failure on line 100 in python/lsst/meas/algorithms/gp_interpolation.py (GitHub Actions / call-workflow / lint): F821 undefined name 'x'
kernel += jnp.eye(len(y_err)) * (y_err**2)
return kernel


@jit
def jax_rbf_kernel_rect(x1, x2, sigma, correlation_length):
"""
Compute the radial basis function (RBF) kernel (rectangular matrix).
Parameters:
-----------
x1 : `np.array`
The first set of input points.
x2 : `np.array`
The second set of input points.
sigma : `float`
The scale parameter of the kernel.
correlation_length : float
The correlation length parameter of the kernel.
Returns:
--------
kernel_rect : `np.array`
The computed RBF kernel (rectangular matrix).
"""
l1 = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
kernel_rect = (sigma**2) * jnp.exp(-0.5 * l1 / (correlation_length**2))
return kernel_rect
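
With both x1 and x2 in its signature, the same routine now builds either the square training covariance (x1 and x2 being the same points, plus the white-noise diagonal) or the rectangular cross-covariance used at prediction time, which is why the separate jax_rbf_kernel_rect above is being removed. A minimal sketch of that idea in plain JAX; the helper name and toy points are illustrative, not part of the module, and the y_err diagonal term is omitted:

import jax.numpy as jnp

def rbf_kernel_sketch(x1, x2, sigma, correlation_length):
    # Pairwise squared distances between the two point sets, shape (len(x1), len(x2)).
    d2 = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return (sigma**2) * jnp.exp(-0.5 * d2 / correlation_length**2)

x_train = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
x_test = jnp.array([[0.5, 0.5]])

K = rbf_kernel_sketch(x_train, x_train, 1.0, 1.0)       # (3, 3) training covariance
K_cross = rbf_kernel_sketch(x_test, x_train, 1.0, 1.0)  # (1, 3) cross-covariance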


@jit

Check failure on line 104 in python/lsst/meas/algorithms/gp_interpolation.py (GitHub Actions / call-workflow / lint): E302 expected 2 blank lines, found 1
def jax_get_alpha(y, kernel):
"""
@@ -141,7 +116,7 @@ def jax_get_alpha(y, kernel):
Returns:
--------
alpha : `np.array`
The alpha vector computed using the Cholesky decomposition and solve.
The alpha vector computed via Cholesky decomposition and triangular solve.
"""
factor = (jax.scipy.linalg.cholesky(kernel, overwrite_a=True, lower=False), False)
@@ -150,16 +125,16 @@ def jax_get_alpha(y, kernel):


@jit
def jax_get_y_predict(kernel_rect, alpha):
def jax_get_gp_predict(kernel_rect, alpha):
"""
Compute the predicted values of y using the given kernel and alpha.
Compute the predicted values of the GP using the given kernel and alpha (Cholesky solution).
Parameters:
-----------
kernel_rect : `np.array`
The kernel matrix.
alpha : `np.array`
The alpha vector.
The alpha vector from the Cholesky solution.
Returns:
--------
@@ -204,24 +179,23 @@ def __init__(self, std=1.0, correlation_length=1.0, white_noise=0.0, mean=0.0):
self.mean = mean
self._alpha = None

def fit(self, x_good, y_good):
y = y_good - self.mean
self._x = x_good
kernel = jax_rbf_kernel(x_good, self.std, self.correlation_length, self.white_noise)
def fit(self, x_train, y_train):
y = y_train - self.mean
self._x = x_train
kernel = jax_rbf_kernel(x_train, self.std, self.correlation_length, self.white_noise)
self._alpha = jax_get_alpha(y, kernel)

def predict(self, x_bad):
kernel_rect = jax_rbf_kernel_rect(x_bad, self._x, self.std, self.correlation_length)
y_pred = jax_get_y_predict(kernel_rect, self._alpha)
def predict(self, x_predict):
kernel_rect = jax_rbf_kernel(x_predict, self._x, self.std, self.correlation_length, 0)
y_pred = jax_get_gp_predict(kernel_rect, self._alpha)
return y_pred + self.mean
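
For reference, the fit/predict path of this JAX backend (covariance, Cholesky solve for alpha, then cross-covariance times alpha) can be sketched end to end. The values below are toy numbers, and jax.scipy.linalg.cho_factor / cho_solve stand in for the cholesky and solve calls used in the module:

import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

x_train = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y_train = jnp.array([0.2, -0.1, 0.3, 0.0])
x_pred = jnp.array([[0.5, 0.5]])
sigma, ell, noise = 1.0, 1.0, 1e-3

def rbf(x1, x2):
    d2 = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return sigma**2 * jnp.exp(-0.5 * d2 / ell**2)

K = rbf(x_train, x_train) + noise**2 * jnp.eye(len(x_train))  # training covariance
alpha = cho_solve(cho_factor(K, lower=False), y_train)         # the "fit" step
y_pred = rbf(x_pred, x_train) @ alpha                          # the "predict" step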


# Vanilla Gaussian Process regression using treegp package
# There is no fancy O(N*log(N)) solver here, just the basic GP regression (Cholesky).
class GaussianProcessTreegp:

Check failure on line 193 in python/lsst/meas/algorithms/gp_interpolation.py (GitHub Actions / call-workflow / lint): E302 expected 2 blank lines, found 1
"""
Gaussian Process Treegp class for Gaussian Process interpolation.
The basic GP regression, which uses Cholesky decomposition.
Parameters:
-----------
std : `float`, optional
@@ -240,15 +214,15 @@ def __init__(self, std=1.0, correlation_length=1.0, white_noise=0.0, mean=0.0):
self.white_noise = white_noise
self.mean = mean

def fit(self, x_good, y_good):
def fit(self, x_train, y_train):
"""
Fit the Gaussian Process to the given training data.
Parameters:
-----------
x_good : `np.array`
x_train : `np.array`
Input features for the training data.
y_good : `np.array`
y_train : `np.array`
Target values for the training data.
"""
kernel = f"{self.std}**2 * RBF({self.correlation_length})"
@@ -258,24 +232,24 @@ def fit(self, x_good, y_good):
normalize=False,
white_noise=self.white_noise,
)
self.gp.initialize(x_good, y_good - self.mean)
self.gp.initialize(x_train, y_train - self.mean)
self.gp.solve()

def predict(self, x_bad):
def predict(self, x_predict):
"""
Predict the target values for the given input features.
Parameters:
-----------
x_bad : `np.array`
x_predict : `np.array`
Input features for the prediction.
Returns:
--------
y_pred : `np.array`
Predicted target values.
"""
y_pred = self.gp.predict(x_bad)
y_pred = self.gp.predict(x_predict)
return y_pred + self.mean
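
A hedged usage sketch of the treegp-backed wrapper above, assuming the class is importable and the treegp package is installed; the coordinates, values, and hyperparameters are toy numbers, not module defaults:

import numpy as np

x_train = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # good-pixel coordinates
y_train = np.array([10.2, 9.8, 10.1, 10.0])                           # good-pixel values

gp = GaussianProcessTreegp(std=0.2, correlation_length=1.0,
                           white_noise=0.05, mean=float(np.mean(y_train)))
gp.fit(x_train, y_train)
y_fill = gp.predict(np.array([[0.5, 0.5]]))  # predicted value at one "bad" location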


@@ -303,30 +277,6 @@ class InterpolateOverDefectGaussianProcess:
correlation_length_cut : `int`, optional
The factor by which to dilate the bounding box around defects. Default is 5.
Raises:
-------
ValueError
If an invalid method is provided.
Attributes:
-----------
bin_spacing : `int`
The spacing between bins for good pixel binning.
threshold_subdivide : `int`
The threshold for sub-dividing the bad pixel array to avoid memory error.
threshold_dynamic_binning : `int`
The threshold for dynamic binning.
maskedImage : `lsst.afw.image.MaskedImage`
The masked image containing the defects to be interpolated.
defects : `list`[`str`]
The types of defects to be interpolated.
correlation_length : `float`
The correlation length (FWHM).
correlation_length_cut : `int`
The factor by which to dilate the bounding box around defects.
interpBit : `int`
The bit mask for the "INTRP" plane in the image mask.
"""

def __init__(
@@ -360,28 +310,30 @@ def __init__(

def interpolate_over_defects(self):
"""
Interpolates over the defects in the image.
Interpolate over the defects in the image.
Modifies self.maskedImage in place.
"""

mask = self.maskedImage.getMask()
badPixelMask = mask.getPlaneBitMask(self.defects)
badMaskSpanSet = SpanSet.fromMask(mask, badPixelMask).split()

bbox = self.maskedImage.getBBox()
glob_xmin, glob_xmax = bbox.minX, bbox.maxX
glob_ymin, glob_ymax = bbox.minY, bbox.maxY
global_xmin, global_xmax = bbox.minX, bbox.maxX
global_ymin, global_ymax = bbox.minY, bbox.maxY

for spanset in badMaskSpanSet:
bbox = spanset.getBBox()
# Dilate the bbox to make sure we have enough good pixels around the defect
# For now, we dilate by 5 times the correlation length
# For GP with isotropic kernel, points at 5 correlation lengths away have negligible
# effect on the prediction.
# For a GP with an isotropic kernel, points more than correlation_length_cut (default 5)
# correlation lengths away have a negligible effect on the prediction.
bbox = bbox.dilatedBy(
int(self.correlation_length * self.correlation_length_cut)
) # need integer as input.
xmin, xmax = max([glob_xmin, bbox.minX]), min(glob_xmax, bbox.maxX)
ymin, ymax = max([glob_ymin, bbox.minY]), min(glob_ymax, bbox.maxY)
xmin, xmax = max([global_xmin, bbox.minX]), min(global_xmax, bbox.maxX)
ymin, ymax = max([global_ymin, bbox.minY]), min(global_ymax, bbox.maxY)
localBox = Box2I(Point2I(xmin, ymin), Extent2I(xmax - xmin, ymax - ymin))
try:
sub_masked_image = self.maskedImage[localBox]
@@ -391,31 +343,31 @@ def interpolate_over_defects(self):
sub_masked_image = self.interpolate_sub_masked_image(sub_masked_image)
self.maskedImage[localBox] = sub_masked_image

def _good_pixel_binning(self, good_pixel):
def _good_pixel_binning(self, pixels):
"""
Performs good pixel binning.
Performs pixel binning using treegp.meanify.
Parameters:
-----------
good_pixel : `np.array`
The array of good pixels.
pixels : `np.array`
The array of pixels.
Returns:
--------
`np.array`
The binned array of good pixels.
The binned array of pixels.
"""

n_pixels = len(good_pixel[:, 0])
n_pixels = len(pixels[:, 0])
dynamic_binning = int(np.sqrt(n_pixels / self.threshold_dynamic_binning))
if n_pixels / self.bin_spacing**2 < n_pixels / dynamic_binning**2:
bin_spacing = self.bin_spacing
else:
bin_spacing = dynamic_binning
binning = treegp.meanify(bin_spacing=bin_spacing, statistics="mean")
binning.add_field(
good_pixel[:, :2],
good_pixel[:, 2:].T,
pixels[:, :2],
pixels[:, 2:].T,
)
binning.meanify()
return np.array(
@@ -454,7 +406,7 @@ def interpolate_sub_masked_image(self, sub_masked_image):
white_noise = np.sqrt(
np.mean(sub_image_array[np.isfinite(sub_image_array)])
)
kernel_amplitude = np.std(good_pixel[:, 2:])
kernel_amplitude = np.max(good_pixel[:, 2:])
if not np.isfinite(kernel_amplitude):
filter_finite = np.isfinite(good_pixel[:, 2:]).T[0]
good_pixel = good_pixel[filter_finite]
@@ -466,23 +418,21 @@ def interpolate_sub_masked_image(self, sub_masked_image):
# The kernel amplitude might be better described by the maximum value of the good pixels,
# since the data are not really a random Gaussian field.
kernel_amplitude = np.max(good_pixel[:, 2:])
else:
kernel_amplitude = np.max(good_pixel[:, 2:])
try:
good_pixel = self._good_pixel_binning(copy.deepcopy(good_pixel))
except Exception:
warnings.warn(
"Binning failed, use original good pixel array in interpolate over."
"Binning failed, use original good pixel array in interpolation."
)

# put this after binning as comupting median is O(n*log(n))
mean = median_with_mad_clipping(good_pixel[:, 2:])
# put this after binning as computing median is O(n*log(n))
clipped_median = median_with_mad_clipping(good_pixel[:, 2:])

gp = self.GaussianProcess(
std=np.sqrt(kernel_amplitude),
correlation_length=self.correlation_length,
white_noise=white_noise,
mean=mean,
mean=clipped_median,
)
gp.fit(good_pixel[:, :2], np.squeeze(good_pixel[:, 2:]))
if bad_pixel.size < self.threshold_subdivide:
@@ -497,7 +447,7 @@ def interpolate_sub_masked_image(self, sub_masked_image):
np.shape(bad_pixel[i:end, 2:])
)

# update_value
# Update values
ctUtils.updateImageFromArray(sub_masked_image.image, bad_pixel)
updateMaskFromArray(sub_masked_image.mask, bad_pixel, self.interpBit)
return sub_masked_image
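
When the number of bad pixels exceeds threshold_subdivide, the prediction above is done in chunks to bound memory use. A minimal sketch of that idea, where gp is any object exposing the fit/predict interface of the classes above and the chunk size is illustrative:

import numpy as np

def predict_in_chunks(gp, bad_xy, chunk_size=20000):
    # Predict the bad-pixel values chunk by chunk instead of all at once.
    values = np.empty(len(bad_xy))
    for i in range(0, len(bad_xy), chunk_size):
        end = min(i + chunk_size, len(bad_xy))
        values[i:end] = gp.predict(bad_xy[i:end])
    return values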
@@ -515,6 +465,10 @@ def interpolateOverDefectsGP(
"""
Interpolate over defects in an image using Gaussian Process interpolation.
This function performs Gaussian Process interpolation over defects in the input image.
It uses the provided defect coordinates to identify and interpolate over the defects.
The interpolated image is not returned; instead, the input image is modified in place.
Parameters
----------
image : `np.array`
@@ -541,12 +495,6 @@ def interpolateOverDefectsGP(
UserWarning
If no defects are found in the image.
Notes
-----
This function performs Gaussian Process interpolation over defects in the input image.
It uses the provided defect coordinates to identify and interpolate over the defects.
The interpolated image is not returned, instead, the input image is modified in-place.
"""
if badList == [] or badList is None:
warnings.warn("WARNING: no defects found. No interpolation performed.")
