Skip to content

Commit

Permalink
added comments for review
Browse files Browse the repository at this point in the history
  • Loading branch information
PFLeget committed Jun 13, 2024
1 parent 0698454 commit 7583b66
Showing 1 changed file with 20 additions and 177 deletions.
197 changes: 20 additions & 177 deletions python/lsst/meas/algorithms/gp_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,6 @@
__all__ = ["interpolateOverDefectsGP"]


# TO DO: REMOVE this helper function.
def get_name_pkl(rep='/sdf/home/l/leget/rubin-user/lsst_dev/tickets/DM-44305/data/'):
    """Return the path of the next ``out_test_<n>.pkl`` debug dump file.

    The index ``n`` is the number of entries in ``rep`` whose name starts
    with ``out_test_``. The name is therefore only "next" if every
    previously returned path was actually written before this call.

    Args:
        rep (str, optional): Directory scanned for existing dumps and used
            as the parent of the returned path. Defaults to the directory
            that was previously hard-coded in this helper.

    Returns:
        str: ``rep`` joined with ``out_test_<n>.pkl``.
    """
    import glob
    import os

    # Escape the directory part so glob metacharacters in it match literally.
    pattern = os.path.join(glob.escape(rep), 'out_test_*')
    n_file = len(glob.glob(pattern))
    return os.path.join(rep, f"out_test_{n_file}.pkl")

@jit
def median_with_mad_clipping(data, mad_multiplier=2.0):

Expand Down Expand Up @@ -57,14 +44,6 @@ def jax_get_y_predict(HT, alpha):

@jit

Check failure on line 45 in python/lsst/meas/algorithms/gp_interpolation.py

View workflow job for this annotation

GitHub Actions / call-workflow / lint

E302

expected 2 blank lines, found 1
def jax_rbf_k(x1, sigma, correlation_length, y_err):
"""Compute the RBF kernel with JAX.
:param x1: The first set of points. (n_samples,)
:param sigma: The amplitude of the kernel.
:param correlation_length: The correlation length of the kernel.
:param y_err: The error of the field. (n_samples)
:param white_noise: The white noise of the field.
"""

l1 = jax_pdist_squared(x1)
K = (sigma**2) * jnp.exp(-0.5 * l1 / (correlation_length**2))
Expand All @@ -74,120 +53,59 @@ def jax_rbf_k(x1, sigma, correlation_length, y_err):

@jit

Check failure on line 54 in python/lsst/meas/algorithms/gp_interpolation.py

View workflow job for this annotation

GitHub Actions / call-workflow / lint

E302

expected 2 blank lines, found 1
def jax_rbf_h(x1, x2, sigma, correlation_length):
    """Evaluate the RBF cross-kernel between two point sets with JAX.

    :param x1: The first set of points, passed as the first argument of
        ``jax_cdist_squared``.
    :param x2: The second set of points, passed as the second argument of
        ``jax_cdist_squared``.
    :param sigma: The amplitude of the kernel.
    :param correlation_length: The correlation length of the kernel.
    :return: ``sigma**2 * exp(-0.5 * d2 / correlation_length**2)`` where
        ``d2`` is the output of ``jax_cdist_squared(x1, x2)``.
    """
    squared_distance = jax_cdist_squared(x1, x2)
    return (sigma**2) * jnp.exp(-0.5 * squared_distance / (correlation_length**2))

def updateMaskFromArray(mask, bad_pixel, interpBit):
    """Set ``interpBit`` in ``mask`` at every position listed in ``bad_pixel``.

    Args:
        mask: Object exposing ``getX0()``, ``getY0()`` and an ``array``
            attribute indexed as ``[y, x]``. It is modified in place.
        bad_pixel (array-like): Rows whose first two entries are the x and
            y coordinates of a pixel, offset by the mask origin. Any
            further columns are ignored.
        interpBit (int): Bit mask OR-ed into each selected mask element.
    """
    x0 = mask.getX0()
    y0 = mask.getY0()
    coords = np.asarray(bad_pixel)
    if coords.size == 0:
        # Nothing to flag; also avoids indexing a shapeless empty array.
        return
    # Shift to array indices; astype(int) truncates toward zero like int().
    x = (coords[:, 0] - x0).astype(int)
    y = (coords[:, 1] - y0).astype(int)
    # OR is idempotent, so repeated coordinates are safe with fancy indexing.
    mask.array[y, x] |= interpBit

class GaussianProcessJax:
def __init__(self, std=1.0, correlation_length=1.0, white_noise=0.0, mean=0.0):
"""
TO DO
"""

self.std = std
self.l = correlation_length

Check failure on line 66 in python/lsst/meas/algorithms/gp_interpolation.py

View workflow job for this annotation

GitHub Actions / call-workflow / lint

E741

ambiguous variable name 'l'
self.white_noise = white_noise
self.mean = mean
self._alpha = None

def fit(self, x_good, y_good):
    """
    Solve the Gaussian Process system for the given training data.

    The instance mean is subtracted from ``y_good``, the training positions
    are kept on ``self._x`` for later use by ``predict``, and the result of
    ``jax_get_alpha`` is cached on ``self._alpha``.

    Args:
        x_good (array-like): The input features of the training data.
        y_good (array-like): The target values of the training data.
    """
    residuals = y_good - self.mean
    self._x = x_good
    # self.white_noise is passed in the ``y_err`` position of jax_rbf_k.
    kernel_matrix = jax_rbf_k(x_good, self.std, self.l, self.white_noise)
    self._alpha = jax_get_alpha(residuals, kernel_matrix)


def predict(self, x_bad):

Check failure on line 79 in python/lsst/meas/algorithms/gp_interpolation.py

View workflow job for this annotation

GitHub Actions / call-workflow / lint

E303

too many blank lines (2)
"""
Makes predictions using the fitted Gaussian Process regression model.
Args:
x_bad (array-like): The input features for which to make predictions.
Returns:
array-like: The predicted target values.

"""
HT = jax_rbf_h(x_bad, self._x, self.std, self.l)
y_pred = jax_get_y_predict(HT, self._alpha)
return y_pred + self.mean


def updateMaskFromArray(mask, bad_pixel, interpBit):

Check failure on line 85 in python/lsst/meas/algorithms/gp_interpolation.py

View workflow job for this annotation

GitHub Actions / call-workflow / lint

E302

expected 2 blank lines, found 1
x0 = mask.getX0()
y0 = mask.getY0()
for row in bad_pixel:
x = int(row[0] - x0)
y = int(row[1] - y0)
mask.array[y, x] |= interpBit

# Vanilla Gaussian Process regression using treegp package
# There is no fancy O(N*log(N)) solver here, just the basic GP regression (Cholesky).
class GaussianProcessTreegp:
"""
Gaussian Process regression using treegp package.
This class implements Gaussian Process regression using the treegp package. It provides methods for fitting the
regression model and making predictions.
Attributes:
std (float): The standard deviation parameter for the Gaussian Process kernel.
l (float): The correlation length parameter for the Gaussian Process kernel.
white_noise (float): The white noise parameter for the Gaussian Process kernel.
mean (float): The mean parameter for the Gaussian Process kernel.
Methods:
fit(x_good, y_good): Fits the Gaussian Process regression model to the given training data.
predict(x_bad): Makes predictions using the fitted Gaussian Process regression model.
"""

def __init__(self, std=1.0, correlation_length=1.0, white_noise=0.0, mean=0.0):
"""
Initializes a new instance of the gp_treegp class.

Args:
std (float, optional): The standard deviation parameter for the Gaussian Process kernel. Defaults to 2.
correlation_length (float, optional): The correlation length parameter for the Gaussian Process kernel.
Defaults to 1.
white_noise (float, optional): The white noise parameter for the Gaussian Process kernel. Defaults to 0.
mean (float, optional): The mean parameter for the Gaussian Process kernel. Defaults to 0.
"""
self.std = std
self.l = correlation_length
self.white_noise = white_noise
self.mean = mean

def fit(self, x_good, y_good):
"""
Fits the Gaussian Process regression model to the given training data.
Args:
x_good (array-like): The input features of the training data.
y_good (array-like): The target values of the training data.

"""
KERNEL = "%.2f**2 * RBF(%f)" % ((self.std, self.l))
kernel = f"{self.std}**2 * RBF({self.l})"
self.gp = treegp.GPInterpolation(
kernel=KERNEL,
kernel=kernel,
optimizer="none",
normalize=False,
white_noise=self.white_noise,
Expand All @@ -196,30 +114,12 @@ def fit(self, x_good, y_good):
self.gp.solve()

def predict(self, x_bad):
    """
    Evaluate the fitted treegp model at new positions.

    Args:
        x_bad (array-like): The input features for which to make predictions.

    Returns:
        array-like: The output of ``self.gp.predict(x_bad)`` with
        ``self.mean`` added.
    """
    return self.gp.predict(x_bad) + self.mean


class InterpolateOverDefectGaussianProcess:
"""
Class for interpolating over defects in a masked image using Gaussian Processes.
Args:
maskedImage (MaskedImage): The masked image containing defects.
defects (list, optional): List of defect names to interpolate over. Defaults to ["SAT"].
fwhm (float, optional): FWHM from PSF and used as prior for correlation length. Defaults to 5.
bin_spacing (float, optional): Spacing for binning. Defaults to 10.
"""

def __init__(
self,
Expand All @@ -230,24 +130,15 @@ def __init__(
bin_spacing=10,
threshold_dynamic_binning=1000,
threshold_subdivide=20000,
correlation_length_cut=5,
):
"""
Initializes the InterpolateOverDefectGaussianProcess class.
Args:
maskedImage (MaskedImage): The masked image containing defects.
defects (list, optional): List of defect names to interpolate over. Defaults to ["SAT"].
fwhm (float, optional): FWHM from PSF and used as prior for correlation length. Defaults to 5.
bin_spacing (float, optional): Spacing for binning. Defaults to 10.
"""

if method not in ["jax", "treegp"]:
raise ValueError("Invalid method. Must be 'jax' or 'treegp'.")

if method == "jax":
self.GaussianProcess = GaussianProcessJax
if method == "treegp":
elif method == "treegp":
self.GaussianProcess = GaussianProcessTreegp
else:
raise ValueError("Invalid method. Must be 'jax' or 'treegp'.")

self.bin_spacing = bin_spacing
self.threshold_subdivide = threshold_subdivide
Expand All @@ -256,31 +147,27 @@ def __init__(
self.maskedImage = maskedImage
self.defects = defects
self.correlation_length = fwhm
self.correlation_length_cut = correlation_length_cut

self.interpBit = self.maskedImage.mask.getPlaneBitMask("INTRP")

def interpolate_over_defects(self):
"""
Interpolates over defects using the spanset method.
"""

mask = self.maskedImage.getMask()
# breakpoint()
badPixelMask = mask.getPlaneBitMask(self.defects)
badMaskSpanSet = SpanSet.fromMask(mask, badPixelMask).split()

bbox = self.maskedImage.getBBox()
glob_xmin, glob_xmax = bbox.minX, bbox.maxX
glob_ymin, glob_ymax = bbox.minY, bbox.maxY

for i in range(len(badMaskSpanSet)):
spanset = badMaskSpanSet[i]
for spanset in badMaskSpanSet:
bbox = spanset.getBBox()
# Dilate the bbox to make sure we have enough good pixels around the defect
# For now, we dilate by 5 times the correlation length
# For GP with isotropic kernel, points at 5 correlation lengths away have negligible
# effect on the prediction.
bbox = bbox.dilatedBy(int(self.correlation_length * 5)) # need integer as input.
bbox = bbox.dilatedBy(int(self.correlation_length * self.correlation_length_cut)) # need integer as input.
xmin, xmax = max([glob_xmin, bbox.minX]), min(glob_xmax, bbox.maxX)
ymin, ymax = max([glob_ymin, bbox.minY]), min(glob_ymax, bbox.maxY)
localBox = Box2I(Point2I(xmin, ymin), Extent2I(xmax - xmin, ymax - ymin))
Expand All @@ -295,16 +182,7 @@ def interpolate_over_defects(self):
self.maskedImage[localBox] = sub_masked_image

def _good_pixel_binning(self, good_pixel):
"""
Performs binning of good pixel data.

Parameters:
- good_pixel (numpy.ndarray): An array containing the good pixel data.
Returns:
- numpy.ndarray: An array containing the binned data.
"""
n_pixels = len(good_pixel[:,0])
dynamic_binning = int(np.sqrt(n_pixels / self.threshold_dynamic_binning))
if n_pixels/self.bin_spacing**2 < n_pixels/dynamic_binning**2:
Expand All @@ -318,17 +196,8 @@ def _good_pixel_binning(self, good_pixel):


def interpolate_sub_masked_image(self, sub_masked_image):
"""
Interpolates over defects in a sub-masked image.
Args:
sub_masked_image (MaskedImage): The sub-masked image containing defects.

Returns:
MaskedImage: The sub-masked image with defects interpolated.
"""

cut = int(self.correlation_length * 5) # need integer as input.
cut = int(self.correlation_length * self.correlation_length_cut) # need integer as input.
bad_pixel, good_pixel = ctUtils.findGoodPixelsAroundBadPixels(
sub_masked_image, self.defects, buffer=cut
)
Expand Down Expand Up @@ -383,33 +252,7 @@ def interpolate_sub_masked_image(self, sub_masked_image):

def interpolateOverDefectsGP(image, fwhm, badList, method="treegp", bin_spacing=25,
threshold_dynamic_binning=1000, threshold_subdivide=20000):
"""
Interpolates over defects in an image using Gaussian Process interpolation.
Args:
image : The input image.
fwhm (float): The full width at half maximum (FWHM) of the PSF, used to approximate the correlation length.
badList (list): A list of defects to interpolate over.
bin_spacing (int, optional): The spacing between bins when resampling. Defaults to 25.
threshold_subdivide (int, optional): The threshold number of bad pixels to subdivide to avoid memory issue. Defaults to 20000.
Returns:
None
"""
#import pickle
#dic = {
# 'image': image,
# 'fwhm': fwhm,
# 'badList': badList,
# 'method': method,
# 'bin_spacing': bin_spacing,
# 'threshold_subdivide':threshold_subdivide,
# }
#file_name_out = get_name_pkl()
#fileout = open(file_name_out, 'wb')
#print(f'JE PASSE PAR LA | {file_name_out}')
#pickle.dump(dic, fileout)
#fileout.close()

if badList == [] or badList is None:
warnings.warn('WARNING: no defects found. No interpolation performed.')
return
Expand Down

0 comments on commit 7583b66

Please sign in to comment.