Skip to content

Commit

Permalink
Improve data masking and add utility function
Browse files Browse the repository at this point in the history
  • Loading branch information
cmsaunders committed Sep 6, 2024
1 parent ae74b0b commit ac47215
Showing 1 changed file with 83 additions and 14 deletions.
97 changes: 83 additions & 14 deletions python/lsst/meas/algorithms/maskStreaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
from skimage.feature import canny
from sklearn.cluster import KMeans
from dataclasses import dataclass
import astropy.units as u

from lsst.afw.geom import SpanSet
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
import lsst.kht
Expand All @@ -45,6 +47,8 @@ class Line:
rho: float
theta: float
sigma: float = 0
chi2: float = 0
modelMaximum: float = 0


class LineCollection:
Expand Down Expand Up @@ -141,6 +145,49 @@ def __init__(self, data, weights, line=None):
self._initLine = line
self.setLineMask(line)

def getLineXY(self, line):
"""Return the pixel coordinates of the ends of the line.
Parameters
----------
line : `Line`
Line for which to fine the endpoints.
Returns
-------
boxIntersections : `np.ndarray`
(x, y) coordinates of the start and endpoints of the line.
"""
theta = line.theta * u.deg
# Get where the line intersects with each line making up bounding box.
# Bottom:
yA = -self._ymax / 2.
xA = (line.rho - yA * np.sin(theta)) / np.cos(theta)
# Left:
xB = -self._xmax / 2.
yB = (line.rho - xB * np.cos(theta)) / np.sin(theta)
# Top:
yC = self._ymax / 2.
xC = (line.rho - yC * np.sin(theta)) / np.cos(theta)
# Right:
xD = self._xmax / 2.
yD = (line.rho - xD * np.cos(theta)) / np.sin(theta)

lineIntersections = np.array([[xA, yA],
[xB, yB],
[xC, yC],
[xD, yD]])
lineIntersections[:, 0] += self._xmax / 2.
lineIntersections[:, 1] += self._ymax / 2.

# The line will necessarily intersect with exactly two edges of the
# bounding box itself.
inBox = ((lineIntersections[:, 0] >= 0) & (lineIntersections[:, 0] <= self._xmax)
& (lineIntersections[:, 1] >= 0) & (lineIntersections[:, 1] <= self._ymax))
boxIntersections = lineIntersections[inBox]

return boxIntersections

def setLineMask(self, line):
"""Set mask around the image region near the line.
Expand Down Expand Up @@ -300,10 +347,8 @@ def fit(self, dChi2Tol=0.1, maxIter=100, log=None):
Returns
-------
outline : `np.ndarray`
Coordinates and inverse width of fit line.
chi2 : `float`
Reduced Chi2 of model fit to data.
outline : `Line`
Coordinates, inverse width, and chi2 of fit line.
fitFailure : `bool`
Boolean where `False` corresponds to a successful fit.
"""
Expand Down Expand Up @@ -348,9 +393,9 @@ def line_search(c, dx):
oldChi2 = chi2
iter += 1

outline = Line(x[0], x[1], abs(x[2])**-1)
outline = Line(x[0], x[1], abs(x[2])**-1, chi2)

return outline, chi2, fitFailure
return outline, fitFailure


class MaskStreaksConfig(pexConfig.Config):
Expand Down Expand Up @@ -425,11 +470,22 @@ class MaskStreaksConfig(pexConfig.Config):
dtype=str,
default="DETECTED"
)
onlyMaskDetected = pexConfig.Field(
doc=("If true, only propagate the part of the streak mask that "
"overlaps with the detection mask."),
dtype=bool,
default=True,
)
streaksMaskPlane = pexConfig.Field(
doc="Name of mask plane holding detected streaks",
dtype=str,
default="STREAK"
)
badMaskPlanes = pexConfig.ListField(
doc="Names of mask planes to use when masking out regions.",
dtype=str,
default=("NO_DATA", "INTRP", "BAD", "SAT", "EDGE"),
)


class MaskStreaksTask(pipeBase.Task):
Expand Down Expand Up @@ -479,7 +535,15 @@ def find(self, maskedImage):
mask = maskedImage.mask
detectionMask = (mask.array & mask.getPlaneBitMask(self.config.detectedMaskPlane))

self.edges = self._cannyFilter(detectionMask)
initEdges = self._cannyFilter(detectionMask)
# Mask out edges of bad regions.
badPixelMask = mask.getPlaneBitMask(self.config.badMaskPlanes)
badMaskSpanSet = SpanSet.fromMask(mask, badPixelMask).split()
for sset in badMaskSpanSet:
sset_dilated = sset.dilated(1)
sset_dilated.clippedTo(mask.getBBox()).setMask(mask, mask.getPlaneBitMask("BAD"))
dilatedBadMask = (mask.array & badPixelMask) > 0
self.edges = initEdges & ~dilatedBadMask
self.lines = self._runKHT(self.edges)

if len(self.lines) == 0:
Expand All @@ -490,14 +554,15 @@ def find(self, maskedImage):
clusters = self._findClusters(self.lines)
fitLines, lineMask = self._fitProfile(clusters, maskedImage)

# The output mask is the intersection of the fit streaks and the image detections
outputMask = lineMask & detectionMask.astype(bool)
if self.config.onlyMaskDetected:
# The output mask is the intersection of the fit streaks and the image detections
lineMask &= detectionMask.astype(bool)

return pipeBase.Struct(
lines=fitLines,
lineClusters=clusters,
originalLines=self.lines,
mask=outputMask,
mask=lineMask,
)

@timeMethod
Expand Down Expand Up @@ -653,8 +718,12 @@ def _fitProfile(self, lines, maskedImage):
"""
data = maskedImage.image.array
weights = maskedImage.variance.array**-1
mask = maskedImage.mask
badPixelMask = mask.getPlaneBitMask(self.config.badMaskPlanes)
badMask = (mask.array & badPixelMask) > 0
# Mask out any pixels with non-finite weights
weights[~np.isfinite(weights) | ~np.isfinite(data)] = 0
weights[badMask] = 0

lineFits = LineCollection([], [])
finalLineMasks = [np.zeros(data.shape, dtype=bool)]
Expand All @@ -667,7 +736,7 @@ def _fitProfile(self, lines, maskedImage):
if lineModel.lineMaskSize == 0:
continue

fit, chi2, fitFailure = lineModel.fit(dChi2Tol=self.config.dChi2Tolerance, log=self.log)
fit, fitFailure = lineModel.fit(dChi2Tol=self.config.dChi2Tolerance, log=self.log)

# Initial estimate should be quite close: fit is deemed unsuccessful if rho or theta
# change more than the allowed bin in rho or theta:
Expand All @@ -686,11 +755,11 @@ def _fitProfile(self, lines, maskedImage):
# Take absolute value, as streaks are allowed to be negative
finalModelMax = abs(finalModel).max()
finalLineMask = abs(finalModel) > self.config.footprintThreshold
# Drop this line if the model profile is below the footprint threshold
# Drop this line if the model profile is below the footprint
# threshold
if not finalLineMask.any():
continue
fit.chi2 = chi2
fit.finalModelMax = finalModelMax
fit.modelMaximum = finalModelMax
lineFits.append(fit)
finalLineMasks.append(finalLineMask)
nFinalLines += 1
Expand Down

0 comments on commit ac47215

Please sign in to comment.