Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-42087: Improve data masking and add utility function #386

Merged
merged 1 commit into from
Sep 13, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 89 additions & 15 deletions python/lsst/meas/algorithms/maskStreaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
from skimage.feature import canny
from sklearn.cluster import KMeans
from dataclasses import dataclass
import astropy.units as u

from lsst.afw.geom import SpanSet
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
import lsst.kht
Expand All @@ -39,12 +41,16 @@
class Line:
"""A simple data class to describe a line profile. The parameter `rho`
describes the distance from the center of the image, `theta` describes
the angle, and `sigma` describes the width of the line.
the angle, `sigma` describes the width of the line, `reducedChi2` gives
the reduced chi2 of the fit, and `modelMaximum` gives the peak value of the
fit line profile.
"""

rho: float
theta: float
sigma: float = 0
reducedChi2: float = 0
modelMaximum: float = 0


class LineCollection:
Expand Down Expand Up @@ -141,6 +147,50 @@ def __init__(self, data, weights, line=None):
self._initLine = line
self.setLineMask(line)

def getLineXY(self, line):
"""Return the pixel coordinates of the ends of the line.

Parameters
----------
line : `Line`
Line for which to find the endpoints.

Returns
-------
boxIntersections : `np.ndarray`
(x, y) coordinates of the start and endpoints of the line.
"""
theta = line.theta * u.deg
# Determine where the line intersects with each edge of the bounding
# box.
# Bottom:
yA = -self._ymax / 2.
xA = (line.rho - yA * np.sin(theta)) / np.cos(theta)
# Left:
xB = -self._xmax / 2.
yB = (line.rho - xB * np.cos(theta)) / np.sin(theta)
# Top:
yC = self._ymax / 2.
xC = (line.rho - yC * np.sin(theta)) / np.cos(theta)
# Right:
xD = self._xmax / 2.
yD = (line.rho - xD * np.cos(theta)) / np.sin(theta)

lineIntersections = np.array([[xA, yA],
[xB, yB],
[xC, yC],
[xD, yD]])
lineIntersections[:, 0] += self._xmax / 2.
lineIntersections[:, 1] += self._ymax / 2.

# The line will necessarily intersect with exactly two edges of the
# bounding box itself.
inBox = ((lineIntersections[:, 0] >= 0) & (lineIntersections[:, 0] <= self._xmax)
& (lineIntersections[:, 1] >= 0) & (lineIntersections[:, 1] <= self._ymax))
boxIntersections = lineIntersections[inBox]

return boxIntersections

def setLineMask(self, line):
"""Set mask around the image region near the line.

Expand Down Expand Up @@ -300,10 +350,8 @@ def fit(self, dChi2Tol=0.1, maxIter=100, log=None):

Returns
-------
outline : `np.ndarray`
Coordinates and inverse width of fit line.
chi2 : `float`
Reduced Chi2 of model fit to data.
outline : `Line`
Coordinates, inverse width, and chi2 of fit line.
fitFailure : `bool`
Boolean where `False` corresponds to a successful fit.
"""
Expand Down Expand Up @@ -348,9 +396,9 @@ def line_search(c, dx):
oldChi2 = chi2
iter += 1

outline = Line(x[0], x[1], abs(x[2])**-1)
outline = Line(x[0], x[1], abs(x[2])**-1, chi2)

return outline, chi2, fitFailure
return outline, fitFailure


class MaskStreaksConfig(pexConfig.Config):
Expand Down Expand Up @@ -425,11 +473,23 @@ class MaskStreaksConfig(pexConfig.Config):
dtype=str,
default="DETECTED"
)
onlyMaskDetected = pexConfig.Field(
doc=("If true, only propagate the part of the streak mask that "
"overlaps with the detection mask."),
dtype=bool,
default=True,
)
streaksMaskPlane = pexConfig.Field(
doc="Name of mask plane holding detected streaks",
dtype=str,
default="STREAK"
)
badMaskPlanes = pexConfig.ListField(
doc=("Names of mask plane regions to ignore entirely when doing streak"
" detection"),
dtype=str,
default=("NO_DATA", "INTRP", "BAD", "SAT", "EDGE"),
)


class MaskStreaksTask(pipeBase.Task):
Expand Down Expand Up @@ -479,7 +539,16 @@ def find(self, maskedImage):
mask = maskedImage.mask
detectionMask = (mask.array & mask.getPlaneBitMask(self.config.detectedMaskPlane))

self.edges = self._cannyFilter(detectionMask)
initEdges = self._cannyFilter(detectionMask)
# Ignore regions with known bad masks, adding a one-pixel buffer around
# each to ensure that the edges of bad regions are also ignored.
badPixelMask = mask.getPlaneBitMask(self.config.badMaskPlanes)
badMaskSpanSet = SpanSet.fromMask(mask, badPixelMask).split()
for sset in badMaskSpanSet:
sset_dilated = sset.dilated(1)
sset_dilated.clippedTo(mask.getBBox()).setMask(mask, mask.getPlaneBitMask("BAD"))
dilatedBadMask = (mask.array & badPixelMask) > 0
self.edges = initEdges & ~dilatedBadMask
self.lines = self._runKHT(self.edges)

if len(self.lines) == 0:
Expand All @@ -490,14 +559,15 @@ def find(self, maskedImage):
clusters = self._findClusters(self.lines)
fitLines, lineMask = self._fitProfile(clusters, maskedImage)

# The output mask is the intersection of the fit streaks and the image detections
outputMask = lineMask & detectionMask.astype(bool)
if self.config.onlyMaskDetected:
# The output mask is the intersection of the fit streaks and the image detections
lineMask &= detectionMask.astype(bool)

return pipeBase.Struct(
lines=fitLines,
lineClusters=clusters,
originalLines=self.lines,
mask=outputMask,
mask=lineMask,
)

@timeMethod
Expand Down Expand Up @@ -653,8 +723,12 @@ def _fitProfile(self, lines, maskedImage):
"""
data = maskedImage.image.array
weights = maskedImage.variance.array**-1
mask = maskedImage.mask
badPixelMask = mask.getPlaneBitMask(self.config.badMaskPlanes)
badMask = (mask.array & badPixelMask) > 0
# Mask out any pixels with non-finite weights
weights[~np.isfinite(weights) | ~np.isfinite(data)] = 0
weights[badMask] = 0

lineFits = LineCollection([], [])
finalLineMasks = [np.zeros(data.shape, dtype=bool)]
Expand All @@ -667,7 +741,7 @@ def _fitProfile(self, lines, maskedImage):
if lineModel.lineMaskSize == 0:
continue

fit, chi2, fitFailure = lineModel.fit(dChi2Tol=self.config.dChi2Tolerance, log=self.log)
fit, fitFailure = lineModel.fit(dChi2Tol=self.config.dChi2Tolerance, log=self.log)

# Initial estimate should be quite close: fit is deemed unsuccessful if rho or theta
# change more than the allowed bin in rho or theta:
Expand All @@ -686,11 +760,11 @@ def _fitProfile(self, lines, maskedImage):
# Take absolute value, as streaks are allowed to be negative
finalModelMax = abs(finalModel).max()
finalLineMask = abs(finalModel) > self.config.footprintThreshold
# Drop this line if the model profile is below the footprint threshold
# Drop this line if the model profile is below the footprint
# threshold
if not finalLineMask.any():
continue
fit.chi2 = chi2
fit.finalModelMax = finalModelMax
fit.modelMaximum = finalModelMax
lineFits.append(fit)
finalLineMasks.append(finalLineMask)
nFinalLines += 1
Expand Down
Loading