From 0431cfbe0ab128b2a7bc03b910254f4c24ccd919 Mon Sep 17 00:00:00 2001 From: Lee Kelvin Date: Tue, 26 Mar 2024 05:46:55 -0700 Subject: [PATCH] Construct new storage architech, non-backwd compat --- .../lsst/meas/algorithms/brightStarStamps.py | 674 +++--------------- python/lsst/meas/algorithms/stamps.py | 430 +++-------- 2 files changed, 204 insertions(+), 900 deletions(-) diff --git a/python/lsst/meas/algorithms/brightStarStamps.py b/python/lsst/meas/algorithms/brightStarStamps.py index 875bfbbca..ba921e3f1 100644 --- a/python/lsst/meas/algorithms/brightStarStamps.py +++ b/python/lsst/meas/algorithms/brightStarStamps.py @@ -19,449 +19,131 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +from __future__ import annotations + """Collection of small images (stamps), each centered on a bright star.""" __all__ = ["BrightStarStamp", "BrightStarStamps"] import logging -from collections.abc import Collection, Mapping +from collections.abc import Mapping from dataclasses import dataclass -from functools import reduce -from operator import ior -import numpy as np -from lsst.afw.geom import SpanSet, Stencil +from lsst.afw.detection import Psf +from lsst.afw.geom import SkyWcs from lsst.afw.image import MaskedImageF -from lsst.afw.math import Property, StatisticsControl, makeStatistics, stringToStatisticsProperty from lsst.afw.table.io import Persistable from lsst.daf.base import PropertyList -from lsst.geom import Point2I +from lsst.geom import Point2D -from .stamps import StampBase, Stamps, readFitsWithOptions +from .stamps import StampBase, StampsBase, readFitsWithOptions logger = logging.getLogger(__name__) @dataclass class BrightStarStamp(StampBase): - """Single stamp centered on a bright star, normalized by its annularFlux. - - Parameters - ---------- - stamp_im : `~lsst.afw.image.MaskedImage` - Pixel data for this postage stamp - gaiaGMag : `float` - Gaia G magnitude for the object in this stamp - gaiaId : `int` - Gaia object identifier - position : `~lsst.geom.Point2I` - Origin of the stamps in its origin exposure (pixels) - archive_elements : `~collections.abc.Mapping`[ `str` , \ - `~lsst.afw.table.io.Persistable`], optional - Archive elements (e.g. Transform / WCS) associated with this stamp. - annularFlux : `float` or None, optional - Flux in an annulus around the object - """ - - stamp_im: MaskedImageF - gaiaGMag: float - gaiaId: int - position: Point2I - archive_elements: Mapping[str, Persistable] | None = None - annularFlux: float | None = None - minValidAnnulusFraction: float = 0.0 - validAnnulusFraction: float | None = None - optimalInnerRadius: int | None = None - optimalOuterRadius: int | None = None + """Single stamp centered on a bright star.""" + + maskedImage: MaskedImageF + psf: Psf + wcs: SkyWcs + visit: int + detector: int + refId: int + refMag: float + position: Point2D + scale: float | None + scaleErr: float | None + pedestal: float | None + pedestalErr: float | None + pedestalScaleCov: float | None + xGradient: float | None + yGradient: float | None + globalReducedChiSquared: float | None + globalDegreesOfFreedom: int | None + psfReducedChiSquared: float | None + psfDegreesOfFreedom: int | None + psfMaskedFluxFrac: float | None @classmethod - def factory(cls, stamp_im, metadata, idx, archive_elements=None, minValidAnnulusFraction=0.0): - """This method is needed to service the FITS reader. We need a standard - interface to construct objects like this. Parameters needed to - construct this object are passed in via a metadata dictionary and then - passed to the constructor of this class. This particular factory - method requires keys: G_MAGS, GAIA_IDS, and ANNULAR_FLUXES. They should - each point to lists of values. + def _getMaskedImageClass(cls) -> type[MaskedImageF]: + return MaskedImageF - Parameters - ---------- - stamp_im : `~lsst.afw.image.MaskedImage` - Pixel data to pass to the constructor - metadata : `dict` - Dictionary containing the information - needed by the constructor. - idx : `int` - Index into the lists in ``metadata`` - archive_elements : `~collections.abc.Mapping`[ `str` , \ - `~lsst.afw.table.io.Persistable`], optional - Archive elements (e.g. Transform / WCS) associated with this stamp. - minValidAnnulusFraction : `float`, optional - The fraction of valid pixels within the normalization annulus of a - star. + @classmethod + def _getArchiveElementNames(cls) -> list[str]: + return ["PSF", "WCS"] - Returns - ------- - brightstarstamp : `BrightStarStamp` - An instance of this class - """ - if "X0" in metadata and "Y0" in metadata: - x0 = metadata["X0"] - y0 = metadata["X0"] - position = Point2I(x0, y0) - else: - position = None + @classmethod + def factory( + cls, + maskedImage: MaskedImageF, + metadata: PropertyList, + idx: int, + archive_elements: Mapping[str, Persistable] | None = None, + ): + assert archive_elements is not None return cls( - stamp_im=stamp_im, - gaiaGMag=metadata["G_MAGS"], - gaiaId=metadata["GAIA_IDS"], - position=position, - archive_elements=archive_elements, - annularFlux=metadata["ANNULAR_FLUXES"], - minValidAnnulusFraction=minValidAnnulusFraction, - validAnnulusFraction=metadata["VALID_PIXELS_FRACTION"], + maskedImage=maskedImage, + psf=archive_elements["PSF"], + wcs=archive_elements["WCS"], + visit=metadata["VISIT"], + detector=metadata["DETECTOR"], + refId=metadata["REFID"], + refMag=metadata["REFMAG"], + position=Point2D(metadata["X_FA"], metadata["Y_FA"]), + scale=metadata["SCALE"], + scaleErr=metadata["SCALE_ERR"], + pedestal=metadata["PEDESTAL"], + pedestalErr=metadata["PEDESTAL_ERR"], + pedestalScaleCov=metadata["PEDESTAL_SCALE_COV"], + xGradient=metadata["X_GRADIENT"], + yGradient=metadata["Y_GRADIENT"], + globalReducedChiSquared=metadata["GLOBAL_REDUCED_CHI_SQUARED"], + globalDegreesOfFreedom=metadata["GLOBAL_DEGREES_OF_FREEDOM"], + psfReducedChiSquared=metadata["PSF_REDUCED_CHI_SQUARED"], + psfDegreesOfFreedom=metadata["PSF_DEGREES_OF_FREEDOM"], + psfMaskedFluxFrac=metadata["PSF_MASKED_FLUX_FRAC"], ) def _getMaskedImage(self): - return self.stamp_im + return self.maskedImage def _getArchiveElements(self): - return self.archive_elements + return {"PSF": self.psf, "WCS": self.wcs} def _getMetadata(self) -> PropertyList | None: md = PropertyList() - md["G_MAG"] = self.gaiaGMag - md["GAIA_ID"] = self.gaiaId - md["X0"] = self.position.x - md["Y0"] = self.position.y - md["ANNULAR_FLUX"] = self.annularFlux - md["VALID_PIXELS_FRACTION"] = self.validAnnulusFraction + md["VISIT"] = self.visit + md["DETECTOR"] = self.detector + md["REFID"] = self.refId + md["REFMAG"] = self.refMag + md["X_FA"] = self.position.x + md["Y_FA"] = self.position.y + md["SCALE"] = self.scale + md["SCALE_ERR"] = self.scaleErr + md["PEDESTAL"] = self.pedestal + md["PEDESTAL_ERR"] = self.pedestalErr + md["PEDESTAL_SCALE_COV"] = self.pedestalScaleCov + md["X_GRADIENT"] = self.xGradient + md["Y_GRADIENT"] = self.yGradient + md["GLOBAL_REDUCED_CHI_SQUARED"] = self.globalReducedChiSquared + md["GLOBAL_DEGREES_OF_FREEDOM"] = self.globalDegreesOfFreedom + md["PSF_REDUCED_CHI_SQUARED"] = self.psfReducedChiSquared + md["PSF_DEGREES_OF_FREEDOM"] = self.psfDegreesOfFreedom + md["PSF_MASKED_FLUX_FRAC"] = self.psfMaskedFluxFrac return md - def measureAndNormalize( - self, - annulus: SpanSet, - statsControl: StatisticsControl = StatisticsControl(), - statsFlag: Property = stringToStatisticsProperty("MEAN"), - badMaskPlanes: Collection[str] = ("BAD", "SAT", "NO_DATA"), - ): - """Compute "annularFlux", the integrated flux within an annulus - around an object's center, and normalize it. - - Since the center of bright stars are saturated and/or heavily affected - by ghosts, we measure their flux in an annulus with a large enough - inner radius to avoid the most severe ghosts and contain enough - non-saturated pixels. - - Parameters - ---------- - annulus : `~lsst.afw.geom.spanSet.SpanSet` - SpanSet containing the annulus to use for normalization. - statsControl : `~lsst.afw.math.statistics.StatisticsControl`, optional - StatisticsControl to be used when computing flux over all pixels - within the annulus. - statsFlag : `~lsst.afw.math.statistics.Property`, optional - statsFlag to be passed on to ``afwMath.makeStatistics`` to compute - annularFlux. Defaults to a simple MEAN. - badMaskPlanes : `collections.abc.Collection` [`str`] - Collection of mask planes to ignore when computing annularFlux. - """ - stampSize = self.stamp_im.getDimensions() - # Create image: science pixel values within annulus, NO_DATA elsewhere - maskPlaneDict = self.stamp_im.mask.getMaskPlaneDict() - annulusImage = MaskedImageF(stampSize, planeDict=maskPlaneDict) - annulusMask = annulusImage.mask - annulusMask.array[:] = 2 ** maskPlaneDict["NO_DATA"] - annulus.copyMaskedImage(self.stamp_im, annulusImage) - # Set mask planes to be ignored. - andMask = reduce(ior, (annulusMask.getPlaneBitMask(bm) for bm in badMaskPlanes)) - statsControl.setAndMask(andMask) - - annulusStat = makeStatistics(annulusImage, statsFlag, statsControl) - # Determine the number of valid (unmasked) pixels within the annulus. - unMasked = annulusMask.array.size - np.count_nonzero(annulusMask.array) - self.validAnnulusFraction = unMasked / annulus.getArea() - logger.info( - "The Star's annulus contains %s valid pixels and the annulus itself contains %s pixels.", - unMasked, - annulus.getArea(), - ) - if unMasked > (annulus.getArea() * self.minValidAnnulusFraction): - # Compute annularFlux. - self.annularFlux = annulusStat.getValue() - logger.info("Annular flux is: %s", self.annularFlux) - else: - raise RuntimeError( - f"Less than {self.minValidAnnulusFraction * 100}% of pixels within the annulus are valid." - ) - if np.isnan(self.annularFlux): - raise RuntimeError("Annular flux computation failed, likely because there are no valid pixels.") - if self.annularFlux < 0: - raise RuntimeError("The annular flux is negative. The stamp can not be normalized!") - # Normalize stamps. - self.stamp_im.image.array /= self.annularFlux - return None - -class BrightStarStamps(Stamps): - """Collection of bright star stamps and associated metadata. - - Parameters - ---------- - starStamps : `collections.abc.Sequence` [`BrightStarStamp`] - Sequence of star stamps. Cannot contain both normalized and - unnormalized stamps. - innerRadius : `int`, optional - Inner radius value, in pixels. This and ``outerRadius`` define the - annulus used to compute the ``"annularFlux"`` values within each - ``starStamp``. Must be provided if ``normalize`` is True. - outerRadius : `int`, optional - Outer radius value, in pixels. This and ``innerRadius`` define the - annulus used to compute the ``"annularFlux"`` values within each - ``starStamp``. Must be provided if ``normalize`` is True. - nb90Rots : `int`, optional - Number of 90 degree rotations required to compensate for detector - orientation. - metadata : `~lsst.daf.base.PropertyList`, optional - Metadata associated with the bright stars. - use_mask : `bool` - If `True` read and write mask data. Default `True`. - use_variance : `bool` - If ``True`` read and write variance data. Default ``False``. - use_archive : `bool` - If ``True`` read and write an Archive that contains a Persistable - associated with each stamp. In the case of bright stars, this is - usually a ``TransformPoint2ToPoint2``, used to warp each stamp - to the same pixel grid before stacking. - - Raises - ------ - ValueError - Raised if one of the star stamps provided does not contain the - required keys. - AttributeError - Raised if there is a mix-and-match of normalized and unnormalized - stamps, stamps normalized with different annulus definitions, or if - stamps are to be normalized but annular radii were not provided. - - Notes - ----- - A butler can be used to read only a part of the stamps, specified by a - bbox: - - >>> starSubregions = butler.get( - "brightStarStamps", - dataId, - parameters={"bbox": bbox} - ) - """ +class BrightStarStamps(StampsBase): def __init__( self, starStamps, - innerRadius=None, - outerRadius=None, - nb90Rots=None, metadata=None, - use_mask=True, - use_variance=False, - use_archive=False, ): - super().__init__(starStamps, metadata, use_mask, use_variance, use_archive) - # From v2 onwards, stamps are now always assumed to be unnormalized - self.normalized = False - self.nb90Rots = nb90Rots - - @classmethod - def initAndNormalize( - cls, - starStamps, - innerRadius, - outerRadius, - nb90Rots=None, - metadata=None, - use_mask=True, - use_variance=False, - use_archive=False, - imCenter=None, - discardNanFluxObjects=True, - forceFindFlux=False, - statsControl=StatisticsControl(), - statsFlag=stringToStatisticsProperty("MEAN"), - badMaskPlanes=("BAD", "SAT", "NO_DATA"), - ): - """Normalize a set of bright star stamps and initialize a - BrightStarStamps instance. - - Since the center of bright stars are saturated and/or heavily affected - by ghosts, we measure their flux in an annulus with a large enough - inner radius to avoid the most severe ghosts and contain enough - non-saturated pixels. - - Parameters - ---------- - starStamps : `collections.abc.Sequence` [`BrightStarStamp`] - Sequence of star stamps. Cannot contain both normalized and - unnormalized stamps. - innerRadius : `int` - Inner radius value, in pixels. This and ``outerRadius`` define the - annulus used to compute the ``"annularFlux"`` values within each - ``starStamp``. - outerRadius : `int` - Outer radius value, in pixels. This and ``innerRadius`` define the - annulus used to compute the ``"annularFlux"`` values within each - ``starStamp``. - nb90Rots : `int`, optional - Number of 90 degree rotations required to compensate for detector - orientation. - metadata : `~lsst.daf.base.PropertyList`, optional - Metadata associated with the bright stars. - use_mask : `bool` - If `True` read and write mask data. Default `True`. - use_variance : `bool` - If ``True`` read and write variance data. Default ``False``. - use_archive : `bool` - If ``True`` read and write an Archive that contains a Persistable - associated with each stamp. In the case of bright stars, this is - usually a ``TransformPoint2ToPoint2``, used to warp each stamp - to the same pixel grid before stacking. - imCenter : `collections.abc.Sequence`, optional - Center of the object, in pixels. If not provided, the center of the - first stamp's pixel grid will be used. - discardNanFluxObjects : `bool` - Whether objects with NaN annular flux should be discarded. - If False, these objects will not be normalized. - forceFindFlux : `bool` - Whether to try to find the flux of objects with NaN annular flux - at a different annulus. - statsControl : `~lsst.afw.math.statistics.StatisticsControl`, optional - StatisticsControl to be used when computing flux over all pixels - within the annulus. - statsFlag : `~lsst.afw.math.statistics.Property`, optional - statsFlag to be passed on to ``~lsst.afw.math.makeStatistics`` to - compute annularFlux. Defaults to a simple MEAN. - badMaskPlanes : `collections.abc.Collection` [`str`] - Collection of mask planes to ignore when computing annularFlux. - - Raises - ------ - ValueError - Raised if one of the star stamps provided does not contain the - required keys. - AttributeError - Raised if there is a mix-and-match of normalized and unnormalized - stamps, stamps normalized with different annulus definitions, or if - stamps are to be normalized but annular radii were not provided. - """ - stampSize = starStamps[0].stamp_im.getDimensions() - if imCenter is None: - imCenter = stampSize[0] // 2, stampSize[1] // 2 - - # Create SpanSet of annulus. - outerCircle = SpanSet.fromShape(outerRadius, Stencil.CIRCLE, offset=imCenter) - innerCircle = SpanSet.fromShape(innerRadius, Stencil.CIRCLE, offset=imCenter) - annulusWidth = outerRadius - innerRadius - if annulusWidth < 1: - raise ValueError("The annulus width must be greater than 1 pixel.") - annulus = outerCircle.intersectNot(innerCircle) - - # Initialize (unnormalized) brightStarStamps instance. - bss = cls( - starStamps, - innerRadius=None, - outerRadius=None, - nb90Rots=nb90Rots, - metadata=metadata, - use_mask=use_mask, - use_variance=use_variance, - use_archive=use_archive, - ) - - # Ensure that no stamps have already been normalized. - bss._checkNormalization(True, innerRadius, outerRadius) - bss._innerRadius, bss._outerRadius = innerRadius, outerRadius - - # Apply normalization. - rejects = [] - badStamps = [] - for stamp in bss._stamps: - try: - stamp.measureAndNormalize( - annulus, statsControl=statsControl, statsFlag=statsFlag, badMaskPlanes=badMaskPlanes - ) - # Stars that are missing from input bright star stamps may - # still have a flux within the normalization annulus. The - # following two lines make sure that these stars are included - # in the subtraction process. Failing to assign the optimal - # radii values may result in an error in the `createAnnulus` - # method of the `SubtractBrightStarsTask` class. An alternative - # to handle this is to create two types of stamps that are - # missing from the input brightStarStamps object. One for those - # that have flux within the normalization annulus and another - # for those that do not have a flux within the normalization - # annulus. - stamp.optimalOuterRadius = outerRadius - stamp.optimalInnerRadius = innerRadius - except RuntimeError as err: - logger.error(err) - # Optionally keep NaN flux objects, for bookkeeping purposes, - # and to avoid having to re-find and redo the preprocessing - # steps needed before bright stars can be subtracted. - if discardNanFluxObjects: - rejects.append(stamp) - elif forceFindFlux: - newInnerRadius = innerRadius - newOuterRadius = outerRadius - while True: - newOuterRadius += annulusWidth - newInnerRadius += annulusWidth - if newOuterRadius > min(imCenter): - logger.info("No flux found for the star with Gaia ID of %s", stamp.gaiaId) - stamp.annularFlux = None - badStamps.append(stamp) - break - newOuterCircle = SpanSet.fromShape(newOuterRadius, Stencil.CIRCLE, offset=imCenter) - newInnerCircle = SpanSet.fromShape(newInnerRadius, Stencil.CIRCLE, offset=imCenter) - newAnnulus = newOuterCircle.intersectNot(newInnerCircle) - try: - stamp.measureAndNormalize( - newAnnulus, - statsControl=statsControl, - statsFlag=statsFlag, - badMaskPlanes=badMaskPlanes, - ) - - except RuntimeError: - stamp.annularFlux = np.nan - logger.error( - "The annular flux was not found for radii %d and %d", - newInnerRadius, - newOuterRadius, - ) - if stamp.annularFlux and stamp.annularFlux > 0: - logger.info("The flux is found within an optimized annulus.") - logger.info( - "The optimized annulus radii are %d and %d and the flux is %f", - newInnerRadius, - newOuterRadius, - stamp.annularFlux, - ) - stamp.optimalOuterRadius = newOuterRadius - stamp.optimalInnerRadius = newInnerRadius - break - else: - stamp.annularFlux = np.nan - - # Remove rejected stamps. - bss.normalized = True - if discardNanFluxObjects: - for reject in rejects: - bss._stamps.remove(reject) - elif forceFindFlux: - for badStamp in badStamps: - bss._stamps.remove(badStamp) - bss._innerRadius, bss._outerRadius = None, None - return bss, badStamps - return bss + super().__init__(starStamps, metadata, useMask=True, useVariance=True, useArchive=True) + self.byRefId = {stamp.refId: stamp for stamp in self} @classmethod def readFits(cls, filename): @@ -485,185 +167,5 @@ def readFitsWithOptions(cls, filename, options): options : `PropertyList` Collection of metadata parameters. """ - stamps, metadata = readFitsWithOptions(filename, BrightStarStamp.factory, options) - nb90Rots = metadata["NB_90_ROTS"] if "NB_90_ROTS" in metadata else None - # For backwards compatibility, always assume stamps are unnormalized. - # This allows older stamps to be read in successfully. - return cls( - stamps, - nb90Rots=nb90Rots, - metadata=metadata, - use_mask=metadata["HAS_MASK"], - use_variance=metadata["HAS_VARIANCE"], - use_archive=metadata["HAS_ARCHIVE"], - ) - - def append(self, item, innerRadius=None, outerRadius=None): - """Add an additional bright star stamp. - - Parameters - ---------- - item : `BrightStarStamp` - Bright star stamp to append. - innerRadius : `int`, optional - Inner radius value, in pixels. This and ``outerRadius`` define the - annulus used to compute the ``"annularFlux"`` values within each - ``BrightStarStamp``. - outerRadius : `int`, optional - Outer radius value, in pixels. This and ``innerRadius`` define the - annulus used to compute the ``"annularFlux"`` values within each - ``BrightStarStamp``. - """ - if not isinstance(item, BrightStarStamp): - raise ValueError(f"Can only add instances of BrightStarStamp, got {type(item)}.") - if (item.annularFlux is None) == self.normalized: - raise AttributeError( - "Trying to append an unnormalized stamp to a normalized BrightStarStamps " - "instance, or vice-versa." - ) - else: - self._checkRadius(innerRadius, outerRadius) - self._stamps.append(item) - return None - - def extend(self, bss): - """Extend BrightStarStamps instance by appending elements from another - instance. - - Parameters - ---------- - bss : `BrightStarStamps` - Other instance to concatenate. - """ - if not isinstance(bss, BrightStarStamps): - raise ValueError(f"Can only extend with a BrightStarStamps object. Got {type(bss)}.") - self._checkRadius(bss._innerRadius, bss._outerRadius) - self._stamps += bss._stamps - - def getMagnitudes(self): - """Retrieve Gaia G-band magnitudes for each star. - - Returns - ------- - gaiaGMags : `list` [`float`] - Gaia G-band magnitudes for each star. - """ - return [stamp.gaiaGMag for stamp in self._stamps] - - def getGaiaIds(self): - """Retrieve Gaia IDs for each star. - - Returns - ------- - gaiaIds : `list` [`int`] - Gaia IDs for each star. - """ - return [stamp.gaiaId for stamp in self._stamps] - - def getAnnularFluxes(self): - """Retrieve normalization factor for each star. - - These are computed by integrating the flux in annulus centered on the - bright star, far enough from center to be beyond most severe ghosts and - saturation. - The inner and outer radii that define the annulus can be recovered from - the metadata. - - Returns - ------- - annularFluxes : `list` [`float`] - Annular fluxes which give the normalization factor for each star. - """ - return [stamp.annularFlux for stamp in self._stamps] - - def getValidPixelsFraction(self): - """Retrieve the fraction of valid pixels within the normalization - annulus for each star. - - Returns - ------- - validPixelsFractions : `list` [`float`] - Fractions of valid pixels within the normalization annulus for each - star. - """ - return [stamp.validAnnulusFraction for stamp in self._stamps] - - def selectByMag(self, magMin=None, magMax=None): - """Return the subset of bright star stamps for objects with specified - magnitude cuts (in Gaia G). - - Parameters - ---------- - magMin : `float`, optional - Keep only stars fainter than this value. - magMax : `float`, optional - Keep only stars brighter than this value. - """ - subset = [ - stamp - for stamp in self._stamps - if (magMin is None or stamp.gaiaGMag > magMin) and (magMax is None or stamp.gaiaGMag < magMax) - ] - # This saves looping over init when guaranteed to be the correct type. - instance = BrightStarStamps( - (), innerRadius=self._innerRadius, outerRadius=self._outerRadius, metadata=self._metadata - ) - instance._stamps = subset - return instance - - def _checkRadius(self, innerRadius, outerRadius): - """Ensure provided annulus radius is consistent with that already - present in the instance, or with arguments passed on at initialization. - """ - if innerRadius != self._innerRadius or outerRadius != self._outerRadius: - raise AttributeError( - f"Trying to mix stamps normalized with annulus radii {innerRadius, outerRadius} with those " - "of BrightStarStamp instance\n" - f"(computed with annular radii {self._innerRadius, self._outerRadius})." - ) - - def _checkNormalization(self, normalize, innerRadius, outerRadius): - """Ensure there is no mixing of normalized and unnormalized stars, and - that, if requested, normalization can be performed. - """ - noneFluxCount = self.getAnnularFluxes().count(None) - nStamps = len(self) - nFluxVals = nStamps - noneFluxCount - if noneFluxCount and noneFluxCount < nStamps: - # At least one stamp contains an annularFlux value (i.e. has been - # normalized), but not all of them do. - raise AttributeError( - f"Only {nFluxVals} stamps contain an annularFlux value.\nAll stamps in a BrightStarStamps " - "instance must either be normalized with the same annulus definition, or none of them can " - "contain an annularFlux value." - ) - elif normalize: - # Stamps are to be normalized; ensure annular radii are specified - # and they have no annularFlux. - if innerRadius is None or outerRadius is None: - raise AttributeError( - "For stamps to be normalized (normalize=True), please provide a valid value (in pixels) " - "for both innerRadius and outerRadius." - ) - elif noneFluxCount < nStamps: - raise AttributeError( - f"{nFluxVals} stamps already contain an annularFlux value. For stamps to be normalized, " - "all their annularFlux must be None." - ) - elif innerRadius is not None and outerRadius is not None: - # Radii provided, but normalize=False; check that stamps already - # contain annularFluxes. - if noneFluxCount: - raise AttributeError( - f"{noneFluxCount} stamps contain no annularFlux, but annular radius values were provided " - "and normalize=False.\nTo normalize stamps, set normalize to True." - ) - else: - # At least one radius value is missing; ensure no stamps have - # already been normalized. - if nFluxVals: - raise AttributeError( - f"{nFluxVals} stamps contain an annularFlux value. If stamps have been normalized, the " - "innerRadius and outerRadius values used must be provided." - ) - return None + stamps, metadata = readFitsWithOptions(filename, BrightStarStamp, options) + return cls(stamps, metadata=metadata) diff --git a/python/lsst/meas/algorithms/stamps.py b/python/lsst/meas/algorithms/stamps.py index 2ee5d1fb6..d15ece031 100644 --- a/python/lsst/meas/algorithms/stamps.py +++ b/python/lsst/meas/algorithms/stamps.py @@ -19,11 +19,14 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +from __future__ import annotations + """Collection of small images (postage stamps).""" -__all__ = ["StampBase", "Stamp", "StampsBase", "Stamps", "writeFits", "readFitsWithOptions"] +__all__ = ["StampBase", "StampsBase", "writeFits", "readFitsWithOptions"] import abc +import typing from collections.abc import Mapping, Sequence from dataclasses import dataclass, field, fields @@ -32,7 +35,7 @@ from lsst.afw.image import ImageFitsReader, MaskedImage, MaskedImageF, MaskFitsReader from lsst.afw.table.io import InputArchive, OutputArchive, Persistable from lsst.daf.base import PropertyList -from lsst.geom import Angle, Box2I, Extent2I, Point2I, SpherePoint, degrees +from lsst.geom import Angle, SpherePoint, degrees from lsst.utils import doImport from lsst.utils.introspection import get_full_type_name @@ -41,12 +44,12 @@ def writeFits( filename: str, - stamps: Sequence, + stamps: Sequence[StampBase], metadata: PropertyList, - type_name: str, - write_mask: bool, - write_variance: bool, - write_archive: bool = False, + typeName: str, + writeMask: bool, + writeVariance: bool, + writeArchive: bool = False, ): """Write a single FITS file containing all stamps. @@ -58,21 +61,21 @@ def writeFits( An iterable of Stamp objects. metadata : `PropertyList` A collection of key:value metadata pairs written to the primary header. - type_name : `str` + typeName : `str` Python type name of the StampsBase subclass to use. - write_mask : `bool` + writeMask : `bool` Write the mask data to the output file? - write_variance : `bool` + writeVariance : `bool` Write the variance data to the output file? - write_archive : `bool`, optional + writeArchive : `bool`, optional Write an archive which stores Persistables along with each stamp? """ # Stored metadata in the primary HDU - metadata["HAS_MASK"] = write_mask - metadata["HAS_VARIANCE"] = write_variance - metadata["HAS_ARCHIVE"] = write_archive + metadata["HAS_MASK"] = writeMask + metadata["HAS_VARIANCE"] = writeVariance + metadata["HAS_ARCHIVE"] = writeArchive metadata["N_STAMPS"] = len(stamps) - metadata["STAMPCLS"] = type_name + metadata["STAMPCLS"] = typeName metadata["VERSION"] = 2 # Record version number in case of future code changes # Create the primary HDU with global metadata @@ -80,49 +83,49 @@ def writeFits( fitsFile.createEmpty() # Store Persistables in an OutputArchive and write it to the primary HDU - if write_archive: - archive_element_ids = [] + if writeArchive: + stamps_archiveElementNames = set() + stamps_archiveElementIds = [] oa = OutputArchive() - archive_element_names = set() for stamp in stamps: - stamp_archive_elements = stamp._getArchiveElements() - archive_element_names.update(stamp_archive_elements.keys()) - archive_element_ids.append( - {name: oa.put(persistable) for name, persistable in stamp_archive_elements.items()} + stamp_archiveElements = stamp._getArchiveElements() + stamps_archiveElementNames.update(stamp_archiveElements.keys()) + stamps_archiveElementIds.append( + {name: oa.put(persistable) for name, persistable in stamp_archiveElements.items()} ) fitsFile.writeMetadata(metadata) oa.writeFits(fitsFile) else: - archive_element_ids = [None] * len(stamps) + stamps_archiveElementIds = [None] * len(stamps) fitsFile.writeMetadata(metadata) fitsFile.closeFile() # Add all pixel data to extension HDUs; optionally write mask/variance info - for i, (stamp, stamp_archive_element_ids) in enumerate(zip(stamps, archive_element_ids)): + for i, (stamp, stamp_archiveElementIds) in enumerate(zip(stamps, stamps_archiveElementIds)): metadata = PropertyList() extVer = i + 1 # EXTVER should be 1-based; the index from enumerate is 0-based metadata.update({"EXTVER": extVer, "EXTNAME": "IMAGE"}) - if stamp_metadata := stamp._getMetadata(): - metadata.update(stamp_metadata) - if stamp_archive_element_ids: - metadata.update(stamp_archive_element_ids) - for archive_element_name in sorted(archive_element_names): - metadata.add("ARCHIVE_ELEMENT", archive_element_name) - stamp.stamp_im.getImage().writeFits(filename, metadata=metadata, mode="a") - if write_mask: + if stampMetadata := stamp._getMetadata(): + metadata.update(stampMetadata) + if stamp_archiveElementIds: + metadata.update(stamp_archiveElementIds) + for stamps_archiveElementName in sorted(stamps_archiveElementNames): + metadata.add("ARCHIVE_ELEMENT", stamps_archiveElementName) + stamp.maskedImage.getImage().writeFits(filename, metadata=metadata, mode="a") + if writeMask: metadata = PropertyList() metadata.update({"EXTVER": extVer, "EXTNAME": "MASK"}) - stamp.stamp_im.getMask().writeFits(filename, metadata=metadata, mode="a") - if write_variance: + stamp.maskedImage.getMask().writeFits(filename, metadata=metadata, mode="a") + if writeVariance: metadata = PropertyList() metadata.update({"EXTVER": extVer, "EXTNAME": "VARIANCE"}) - stamp.stamp_im.getVariance().writeFits(filename, metadata=metadata, mode="a") + stamp.maskedImage.getVariance().writeFits(filename, metadata=metadata, mode="a") return None def readFitsWithOptions( filename: str, - stamp_factory: classmethod, + stamp_cls: type[StampBase], options: PropertyList, ): """Read stamps from FITS file, allowing for only a subregion of the stamps @@ -132,15 +135,11 @@ def readFitsWithOptions( ---------- filename : `str` A string indicating the file to read - stamp_factory : classmethod + stampFactory : classmethod A factory function defined on a dataclass for constructing stamp objects a la `~lsst.meas.algorithm.Stamp` options : `PropertyList` or `dict` - A collection of parameters. If it contains a bounding box - (``bbox`` key), or if certain other keys (``llcX``, ``llcY``, - ``width``, ``height``) are available for one to be constructed, - the bounding box is passed to the ``FitsReader`` in order to - return a sub-image. + A collection of parameters. Returns ------- @@ -158,68 +157,50 @@ def readFitsWithOptions( # Extract necessary info from metadata metadata = readMetadata(filename, hdu=0) nStamps = metadata["N_STAMPS"] - has_archive = metadata["HAS_ARCHIVE"] - archive_element_names = None - archive_element_ids_v1 = None - if has_archive: + hasArchive = metadata["HAS_ARCHIVE"] + stamps_archiveElementNames = None + stamps_archiveElementIds_v1 = None + if hasArchive: if metadata["VERSION"] < 2: - archive_element_ids_v1 = metadata.getArray("ARCHIVE_IDS") + stamps_archiveElementIds_v1 = metadata.getArray("ARCHIVE_IDS") else: - archive_element_names = metadata.getArray("ARCHIVE_ELEMENT") + stamps_archiveElementNames = stamp_cls._getArchiveElementNames() with Fits(filename, "r") as fitsFile: nExtensions = fitsFile.countHdus() - # check if a bbox was provided - kwargs = {} - if options: - # gen3 API - if "bbox" in options.keys(): - kwargs["bbox"] = options["bbox"] - # gen2 API - elif "llcX" in options.keys(): - llcX = options["llcX"] - llcY = options["llcY"] - width = options["width"] - height = options["height"] - bbox = Box2I(Point2I(llcX, llcY), Extent2I(width, height)) - kwargs["bbox"] = bbox - stamp_parts = {} + stampParts = {} # Determine the dtype from the factory. # This allows a Stamp class to be defined in terms of MaskedImageD or # MaskedImageI without forcing everything to floats. - masked_image_cls = None - for stamp_field in fields(stamp_factory.__self__): - if issubclass(stamp_field.type, MaskedImage): - masked_image_cls = stamp_field.type - break - else: - raise RuntimeError("Stamp factory does not use MaskedImage.") - default_dtype = np.dtype(masked_image_cls.dtype) + maskedImageCls = stamp_cls._getMaskedImageClass() + default_dtype = np.dtype(maskedImageCls.dtype) variance_dtype = np.dtype(np.float32) # Variance is always the same type # We need to be careful because nExtensions includes the primary HDU - stamp_metadata = {} - archive_element_ids = {} + stampMetadata = {} + stamps_archiveElementIds = {} for idx in range(nExtensions - 1): dtype = None hduNum = idx + 1 md = readMetadata(filename, hdu=hduNum) if md["EXTNAME"] in ("IMAGE", "VARIANCE"): + stampId = md["EXTVER"] reader = ImageFitsReader(filename, hdu=hduNum) if md["EXTNAME"] == "VARIANCE": dtype = variance_dtype else: dtype = default_dtype - if archive_element_names is not None: - archive_element_ids[idx] = { - name: archive_id - for name in archive_element_names - if (archive_id := md.pop(name, None)) + if stamps_archiveElementNames is not None: + stamps_archiveElementIds[stampId] = { + name: archiveId + for name in stamps_archiveElementNames + if (archiveId := md.pop(name, None)) } # md.remove("EXTNAME") # md.remove("EXTVER") - stamp_metadata[idx] = md + stampMetadata[stampId] = md elif md["EXTNAME"] == "MASK": + stampId = md["EXTVER"] reader = MaskFitsReader(filename, hdu=hduNum) elif md["EXTNAME"] == "ARCHIVE_INDEX": fitsFile.setHdu(hduNum) @@ -229,36 +210,36 @@ def readFitsWithOptions( continue else: raise ValueError(f"Unknown extension type: {md['EXTNAME']}") - stamp_parts.setdefault(md["EXTVER"], {})[md["EXTNAME"].lower()] = reader.read( - dtype=dtype, **kwargs - ) + stampParts.setdefault(stampId, {})[md["EXTNAME"].lower()] = reader.read(dtype=dtype) - if len(stamp_parts) != nStamps: + if len(stampParts) != nStamps: raise ValueError( - f"Number of stamps read ({len(stamp_parts)}) does not agree with the " + f"Number of stamps read ({len(stampParts)}) does not agree with the " f"number of stamps recorded in the metadata ({nStamps})." ) # Construct the stamps themselves stamps = [] for k in range(nStamps): # Need to increment by one since EXTVER starts at 1 - maskedImage = masked_image_cls(**stamp_parts[k + 1]) - if archive_element_ids_v1 is not None: - archive_elements = {DEFAULT_ARCHIVE_ELEMENT_NAME: archive.get(archive_element_ids_v1[k])} - elif archive_element_names is not None: - stamp_archive_element_ids = archive_element_ids.get(k, {}) - archive_elements = {name: archive.get(id) for name, id in stamp_archive_element_ids.items()} + maskedImage = maskedImageCls(**stampParts[k + 1]) + if stamps_archiveElementIds_v1 is not None: + stamp_archiveElements = { + DEFAULT_ARCHIVE_ELEMENT_NAME: archive.get(stamps_archiveElementIds_v1[k]) + } + elif stamps_archiveElementNames is not None: + stamp_archiveElementIds = stamps_archiveElementIds.get(k + 1, {}) + stamp_archiveElements = {name: archive.get(id) for name, id in stamp_archiveElementIds.items()} else: - archive_elements = None + stamp_archiveElements = None if metadata["VERSION"] < 2: - stamps.append(stamp_factory(maskedImage, metadata, k, archive_elements)) + stamps.append(stamp_cls.factory(maskedImage, metadata, k, stamp_archiveElements)) else: - stamps.append(stamp_factory(maskedImage, stamp_metadata[k], k, archive_elements)) + stamps.append(stamp_cls.factory(maskedImage, stampMetadata[k + 1], k, stamp_archiveElements)) return stamps, metadata -def _default_position(): +def _defaultPosition(): # SpherePoint is nominally mutable in C++ so we must use a factory # and return an entirely new SpherePoint each time a Stamps is created. return SpherePoint(Angle(np.nan), Angle(np.nan)) @@ -266,22 +247,32 @@ def _default_position(): @dataclass class StampBase(abc.ABC): - """Single abstract stamp. + """Single abstract postage stamp. - Parameters - ---------- - Inherit from this class to add metadata to the stamp. + Notes + ----- + Inherit from this class to add metadata to the postage stamp. """ + @classmethod + @abc.abstractmethod + def _getMaskedImageClass(cls) -> type[MaskedImage]: + """Return the class of the MaskedImage object to be used.""" + raise NotImplementedError() + + @classmethod + def _getArchiveElementNames(cls) -> list[str]: + return [] + @classmethod @abc.abstractmethod def factory( cls, - stamp_im: MaskedImage, + maskedImage: MaskedImageF, metadata: PropertyList, index: int, - archive_elements: Mapping[str, Persistable] | None = None, - ): + archiveElements: Mapping[str, Persistable] | None = None, + ) -> typing.Self: """This method is needed to service the FITS reader. We need a standard interface to construct objects like this. Parameters needed to construct this object are passed in via a metadata @@ -289,14 +280,13 @@ def factory( Parameters ---------- - stamp_im : `~lsst.afw.image.MaskedImage` + maskedImage : `~lsst.afw.image.MaskedImageF` Pixel data to pass to the constructor metadata : `PropertyList` - Dictionary containing the information - needed by the constructor. - idx : `int` + Dictionary containing the information needed by the constructor. + index : `int` Index into the lists in ``metadata`` - archive_elements : `~collections.abc.Mapping`[ `str` , \ + archiveElements : `~collections.abc.Mapping`[ `str` , \ `~lsst.afw.table.io.Persistable`], optional Archive elements (e.g. Transform / WCS) associated with this stamp. @@ -326,89 +316,6 @@ def _getMetadata(self) -> PropertyList | None: return None -@dataclass -class Stamp(StampBase): - """Single stamp. - - Parameters - ---------- - stamp_im : `~lsst.afw.image.MaskedImageF` - The actual pixel values for the postage stamp. - archive_element : `~lsst.afw.table.io.Persistable` or `None`, optional - Archive element (e.g. Transform or WCS) associated with this stamp. - position : `~lsst.geom.SpherePoint` or `None`, optional - Position of the center of the stamp. Note the user must keep track of - the coordinate system. - """ - - stamp_im: MaskedImageF - archive_element: Persistable | None = None - position: SpherePoint | None = field(default_factory=_default_position) - - @classmethod - def factory(cls, stamp_im: MaskedImage, metadata: PropertyList, index: int, archive_elements=None): - """This method is needed to service the FITS reader. We need a standard - interface to construct objects like this. Parameters needed to - construct this object are passed in via a metadata dictionary and then - passed to the constructor of this class. If lists of values are passed - with the following keys, they will be passed to the constructor, - otherwise dummy values will be passed: RA_DEG, DEC_DEG. They should - each point to lists of values. - - Parameters - ---------- - stamp : `~lsst.afw.image.MaskedImage` - Pixel data to pass to the constructor - metadata : `dict` - Dictionary containing the information - needed by the constructor. - idx : `int` - Index into the lists in ``metadata`` - archive_elements : `~collections.abc.Mapping`[ `str` , \ - `~lsst.afw.table.io.Persistable`], optional - Archive elements (e.g. Transform / WCS) associated with this stamp. - - Returns - ------- - stamp : `Stamp` - An instance of this class - """ - if archive_elements: - try: - (archive_element,) = archive_elements.values() - except TypeError: - raise RuntimeError("Expected exactly one archive element.") - else: - archive_element = None - - if "RA_DEG" in metadata and "DEC_DEG" in metadata: - return cls( - stamp_im=stamp_im, - archive_element=archive_element, - position=SpherePoint( - Angle(metadata.getArray("RA_DEG")[index], degrees), - Angle(metadata.getArray("DEC_DEG")[index], degrees), - ), - ) - else: - return cls( - stamp_im=stamp_im, - archive_element=archive_element, - position=SpherePoint(Angle(np.nan), Angle(np.nan)), - ) - - def _getMaskedImage(self): - return self.stamp_im - - def _getArchiveElements(self): - return {DEFAULT_ARCHIVE_ELEMENT_NAME: self.archive_element} - - def _getMetadata(self): - md = PropertyList() - md["RA_DEG"] = self.position.getRa().asDegrees() - md["DEC_DEG"] = self.position.getDec().asDegrees() - - class StampsBase(abc.ABC, Sequence): """Collection of stamps and associated metadata. @@ -419,43 +326,32 @@ class StampsBase(abc.ABC, Sequence): a la ``~lsst.meas.algorithms.Stamp``. metadata : `~lsst.daf.base.PropertyList`, optional Metadata associated with the objects within the stamps. - use_mask : `bool`, optional + useMask : `bool`, optional If ``True`` read and write the mask data. Default ``True``. - use_variance : `bool`, optional + useVariance : `bool`, optional If ``True`` read and write the variance data. Default ``True``. - use_archive : `bool`, optional + useArchive : `bool`, optional If ``True``, read and write an Archive that contains a Persistable associated with each stamp, for example a Transform or a WCS. Default ``False``. - - Notes - ----- - A butler can be used to read only a part of the stamps, - specified by a bbox: - - >>> starSubregions = butler.get( - "brightStarStamps", - dataId, - parameters={"bbox": bbox} - ) """ def __init__( self, - stamps: list, + stamps: Sequence[StampBase], metadata: PropertyList | None = None, - use_mask: bool = True, - use_variance: bool = True, - use_archive: bool = False, + useMask: bool = True, + useVariance: bool = True, + useArchive: bool = False, ): for stamp in stamps: if not isinstance(stamp, StampBase): raise ValueError(f"The entries in stamps must inherit from StampBase. Got {type(stamp)}.") - self._stamps = stamps + self._stamps = list(stamps) self._metadata = PropertyList() if metadata is None else metadata.deepCopy() - self.use_mask = use_mask - self.use_variance = use_variance - self.use_archive = use_archive + self.useMask = useMask + self.useVariance = useVariance + self.useArchive = useArchive @classmethod def readFits(cls, filename: str): @@ -489,16 +385,16 @@ def readFitsWithOptions(cls, filename: str, options: PropertyList): # Load metadata to get the class metadata = readMetadata(filename, hdu=0) - type_name = metadata.get("STAMPCLS") - if type_name is None: + typeName = metadata.get("STAMPCLS") + if typeName is None: raise RuntimeError( f"No class name in file {filename}. Unable to instantiate correct stamps subclass. " "Is this an old version format Stamps file?" ) # Import class and override `cls` - stamp_type = doImport(type_name) - cls = stamp_type + stampType = doImport(typeName) + cls = stampType return cls.readFitsWithOptions(filename, options) @@ -510,15 +406,15 @@ def writeFits(self, filename: str): filename : `str` Name of the FITS file to write. """ - type_name = get_full_type_name(self) + typeName = get_full_type_name(self) writeFits( filename=filename, stamps=self._stamps, metadata=self._metadata, - type_name=type_name, - write_mask=self.use_mask, - write_variance=self.use_variance, - write_archive=self.use_archive, + typeName=typeName, + writeMask=self.useMask, + writeVariance=self.useVariance, + writeArchive=self.useArchive, ) def __len__(self): @@ -530,100 +426,6 @@ def __getitem__(self, index): def __iter__(self): return iter(self._stamps) - def getMaskedImages(self): - """Retrieve star images. - - Returns - ------- - maskedImages : - `list` [`~lsst.afw.image.MaskedImageF`] - """ - return [stamp._getMaskedImage() for stamp in self._stamps] - - def getArchiveElements(self): - """Retrieve archive elements associated with each stamp. - - Returns - ------- - archiveElements : `list` [`dict`[ `str`, \ - `~lsst.afw.table.io.Persistable` ]] - A list of archive elements associated with each stamp. - """ - return [stamp._getArchiveElements() for stamp in self._stamps] - @property def metadata(self): return self._metadata - - -class Stamps(StampsBase): - - def getPositions(self): - return [stamp.position for stamp in self._stamps] - - def append(self, item: Stamp): - """Add an additional stamp. - - Parameters - ---------- - item : `Stamp` - Stamp object to append. - """ - if not isinstance(item, Stamp): - raise ValueError("Objects added must be a Stamp object.") - self._stamps.append(item) - return None - - def extend(self, stamp_list: list[Stamp]): - """Extend Stamps instance by appending elements from another instance. - - Parameters - ---------- - stamps_list : `list` [`Stamp`] - List of Stamp object to append. - """ - for stamp in stamp_list: - if not isinstance(stamp, Stamp): - raise ValueError("Can only extend with Stamp objects") - self._stamps += stamp_list - - @classmethod - def readFits(cls, filename: str): - """Build an instance of this class from a file. - - Parameters - ---------- - filename : `str` - Name of the file to read. - - Returns - ------- - object : `Stamps` - An instance of this class. - """ - return cls.readFitsWithOptions(filename, None) - - @classmethod - def readFitsWithOptions(cls, filename: str, options: PropertyList): - """Build an instance of this class with options. - - Parameters - ---------- - filename : `str` - Name of the file to read. - options : `PropertyList` or `dict` - Collection of metadata parameters. - - Returns - ------- - object : `Stamps` - An instance of this class. - """ - stamps, metadata = readFitsWithOptions(filename, Stamp.factory, options) - return cls( - stamps, - metadata=metadata, - use_mask=metadata["HAS_MASK"], - use_variance=metadata["HAS_VARIANCE"], - use_archive=metadata["HAS_ARCHIVE"], - )