From ac47215df83750b243a8f5ae28bda7b78fc52f3b Mon Sep 17 00:00:00 2001 From: Clare Saunders Date: Fri, 6 Sep 2024 13:08:18 -0700 Subject: [PATCH] Improve data masking and add utility function --- python/lsst/meas/algorithms/maskStreaks.py | 97 ++++++++++++++++++---- 1 file changed, 83 insertions(+), 14 deletions(-) diff --git a/python/lsst/meas/algorithms/maskStreaks.py b/python/lsst/meas/algorithms/maskStreaks.py index a7defb511..8427852e4 100644 --- a/python/lsst/meas/algorithms/maskStreaks.py +++ b/python/lsst/meas/algorithms/maskStreaks.py @@ -28,7 +28,9 @@ from skimage.feature import canny from sklearn.cluster import KMeans from dataclasses import dataclass +import astropy.units as u +from lsst.afw.geom import SpanSet import lsst.pex.config as pexConfig import lsst.pipe.base as pipeBase import lsst.kht @@ -45,6 +47,8 @@ class Line: rho: float theta: float sigma: float = 0 + chi2: float = 0 + modelMaximum: float = 0 class LineCollection: @@ -141,6 +145,49 @@ def __init__(self, data, weights, line=None): self._initLine = line self.setLineMask(line) + def getLineXY(self, line): + """Return the pixel coordinates of the ends of the line. + + Parameters + ---------- + line : `Line` + Line for which to fine the endpoints. + + Returns + ------- + boxIntersections : `np.ndarray` + (x, y) coordinates of the start and endpoints of the line. + """ + theta = line.theta * u.deg + # Get where the line intersects with each line making up bounding box. + # Bottom: + yA = -self._ymax / 2. + xA = (line.rho - yA * np.sin(theta)) / np.cos(theta) + # Left: + xB = -self._xmax / 2. + yB = (line.rho - xB * np.cos(theta)) / np.sin(theta) + # Top: + yC = self._ymax / 2. + xC = (line.rho - yC * np.sin(theta)) / np.cos(theta) + # Right: + xD = self._xmax / 2. + yD = (line.rho - xD * np.cos(theta)) / np.sin(theta) + + lineIntersections = np.array([[xA, yA], + [xB, yB], + [xC, yC], + [xD, yD]]) + lineIntersections[:, 0] += self._xmax / 2. + lineIntersections[:, 1] += self._ymax / 2. + + # The line will necessarily intersect with exactly two edges of the + # bounding box itself. + inBox = ((lineIntersections[:, 0] >= 0) & (lineIntersections[:, 0] <= self._xmax) + & (lineIntersections[:, 1] >= 0) & (lineIntersections[:, 1] <= self._ymax)) + boxIntersections = lineIntersections[inBox] + + return boxIntersections + def setLineMask(self, line): """Set mask around the image region near the line. @@ -300,10 +347,8 @@ def fit(self, dChi2Tol=0.1, maxIter=100, log=None): Returns ------- - outline : `np.ndarray` - Coordinates and inverse width of fit line. - chi2 : `float` - Reduced Chi2 of model fit to data. + outline : `Line` + Coordinates, inverse width, and chi2 of fit line. fitFailure : `bool` Boolean where `False` corresponds to a successful fit. """ @@ -348,9 +393,9 @@ def line_search(c, dx): oldChi2 = chi2 iter += 1 - outline = Line(x[0], x[1], abs(x[2])**-1) + outline = Line(x[0], x[1], abs(x[2])**-1, chi2) - return outline, chi2, fitFailure + return outline, fitFailure class MaskStreaksConfig(pexConfig.Config): @@ -425,11 +470,22 @@ class MaskStreaksConfig(pexConfig.Config): dtype=str, default="DETECTED" ) + onlyMaskDetected = pexConfig.Field( + doc=("If true, only propagate the part of the streak mask that " + "overlaps with the detection mask."), + dtype=bool, + default=True, + ) streaksMaskPlane = pexConfig.Field( doc="Name of mask plane holding detected streaks", dtype=str, default="STREAK" ) + badMaskPlanes = pexConfig.ListField( + doc="Names of mask planes to use when masking out regions.", + dtype=str, + default=("NO_DATA", "INTRP", "BAD", "SAT", "EDGE"), + ) class MaskStreaksTask(pipeBase.Task): @@ -479,7 +535,15 @@ def find(self, maskedImage): mask = maskedImage.mask detectionMask = (mask.array & mask.getPlaneBitMask(self.config.detectedMaskPlane)) - self.edges = self._cannyFilter(detectionMask) + initEdges = self._cannyFilter(detectionMask) + # Mask out edges of bad regions. + badPixelMask = mask.getPlaneBitMask(self.config.badMaskPlanes) + badMaskSpanSet = SpanSet.fromMask(mask, badPixelMask).split() + for sset in badMaskSpanSet: + sset_dilated = sset.dilated(1) + sset_dilated.clippedTo(mask.getBBox()).setMask(mask, mask.getPlaneBitMask("BAD")) + dilatedBadMask = (mask.array & badPixelMask) > 0 + self.edges = initEdges & ~dilatedBadMask self.lines = self._runKHT(self.edges) if len(self.lines) == 0: @@ -490,14 +554,15 @@ def find(self, maskedImage): clusters = self._findClusters(self.lines) fitLines, lineMask = self._fitProfile(clusters, maskedImage) - # The output mask is the intersection of the fit streaks and the image detections - outputMask = lineMask & detectionMask.astype(bool) + if self.config.onlyMaskDetected: + # The output mask is the intersection of the fit streaks and the image detections + lineMask &= detectionMask.astype(bool) return pipeBase.Struct( lines=fitLines, lineClusters=clusters, originalLines=self.lines, - mask=outputMask, + mask=lineMask, ) @timeMethod @@ -653,8 +718,12 @@ def _fitProfile(self, lines, maskedImage): """ data = maskedImage.image.array weights = maskedImage.variance.array**-1 + mask = maskedImage.mask + badPixelMask = mask.getPlaneBitMask(self.config.badMaskPlanes) + badMask = (mask.array & badPixelMask) > 0 # Mask out any pixels with non-finite weights weights[~np.isfinite(weights) | ~np.isfinite(data)] = 0 + weights[badMask] = 0 lineFits = LineCollection([], []) finalLineMasks = [np.zeros(data.shape, dtype=bool)] @@ -667,7 +736,7 @@ def _fitProfile(self, lines, maskedImage): if lineModel.lineMaskSize == 0: continue - fit, chi2, fitFailure = lineModel.fit(dChi2Tol=self.config.dChi2Tolerance, log=self.log) + fit, fitFailure = lineModel.fit(dChi2Tol=self.config.dChi2Tolerance, log=self.log) # Initial estimate should be quite close: fit is deemed unsuccessful if rho or theta # change more than the allowed bin in rho or theta: @@ -686,11 +755,11 @@ def _fitProfile(self, lines, maskedImage): # Take absolute value, as streaks are allowed to be negative finalModelMax = abs(finalModel).max() finalLineMask = abs(finalModel) > self.config.footprintThreshold - # Drop this line if the model profile is below the footprint threshold + # Drop this line if the model profile is below the footprint + # threshold if not finalLineMask.any(): continue - fit.chi2 = chi2 - fit.finalModelMax = finalModelMax + fit.modelMaximum = finalModelMax lineFits.append(fit) finalLineMasks.append(finalLineMask) nFinalLines += 1