diff --git a/python/lsst/meas/algorithms/stamps.py b/python/lsst/meas/algorithms/stamps.py index 9352326bd..dcf7bcdae 100644 --- a/python/lsst/meas/algorithms/stamps.py +++ b/python/lsst/meas/algorithms/stamps.py @@ -19,12 +19,12 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -"""Collection of small images (stamps).""" +"""Collection of small images (postage stamps).""" -__all__ = ["Stamp", "Stamps", "StampsBase", "writeFits", "readFitsWithOptions"] +__all__ = ["StampBase", "Stamp", "StampsBase", "Stamps", "writeFits", "readFitsWithOptions"] import abc -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from dataclasses import dataclass, field, fields import numpy as np @@ -36,67 +36,94 @@ from lsst.utils import doImport from lsst.utils.introspection import get_full_type_name +DEFAULT_ARCHIVE_ELEMENT_NAME = "ELEMENT" -def writeFits(filename, stamps, metadata, type_name, write_mask, write_variance, write_archive=False): + +def writeFits( + filename: str, + stamps: Sequence, + metadata: PropertyList, + type_name: str, + write_mask: bool, + write_variance: bool, + write_archive: bool = False, +): """Write a single FITS file containing all stamps. Parameters ---------- filename : `str` - A string indicating the output filename - stamps : iterable of `BaseStamp` - An iterable of Stamp objects + A string indicating the output filename. + stamps : iterable of `StampBase` + An iterable of Stamp objects. metadata : `PropertyList` - A collection of key, value metadata pairs to be - written to the primary header + A collection of key:value metadata pairs written to the primary header. type_name : `str` - Python type name of the StampsBase subclass to use + Python type name of the StampsBase subclass to use. write_mask : `bool` Write the mask data to the output file? write_variance : `bool` Write the variance data to the output file? write_archive : `bool`, optional - Write an archive to store Persistables along with each stamp? - Default: ``False``. + Write an archive which stores Persistables along with each stamp? """ + # Stored metadata in the primary HDU metadata["HAS_MASK"] = write_mask metadata["HAS_VARIANCE"] = write_variance metadata["HAS_ARCHIVE"] = write_archive metadata["N_STAMPS"] = len(stamps) metadata["STAMPCLS"] = type_name - # Record version number in case of future code changes - metadata["VERSION"] = 1 - # create primary HDU with global metadata + metadata["VERSION"] = 2 # Record version number in case of future code changes + + # Create the primary HDU with global metadata fitsFile = Fits(filename, "w") fitsFile.createEmpty() - # Store Persistables in an OutputArchive and write it + + # Store Persistables in an OutputArchive and write it to the primary HDU if write_archive: + archive_element_ids = [] oa = OutputArchive() - archive_ids = [oa.put(stamp.archive_element) for stamp in stamps] - metadata["ARCHIVE_IDS"] = archive_ids + archive_element_names = set() + for stamp in stamps: + stamp_archive_elements = stamp._getArchiveElements() + archive_element_names.update(stamp_archive_elements.keys()) + archive_element_ids.append( + {name: oa.put(persistable) for name, persistable in stamp_archive_elements.items()} + ) fitsFile.writeMetadata(metadata) oa.writeFits(fitsFile) else: + archive_element_ids = [None] * len(stamps) fitsFile.writeMetadata(metadata) fitsFile.closeFile() - # add all pixel data optionally writing mask and variance information - for i, stamp in enumerate(stamps): + + # Add all pixel data to extension HDUs; optionally write mask/variance info + for i, (stamp, stamp_archive_element_ids) in enumerate(zip(stamps, archive_element_ids)): metadata = PropertyList() - # EXTVER should be 1-based, the index from enumerate is 0-based - metadata.update({"EXTVER": i + 1, "EXTNAME": "IMAGE"}) + extVer = i + 1 # EXTVER should be 1-based; the index from enumerate is 0-based + metadata.update({"EXTVER": extVer, "EXTNAME": "IMAGE"}) + # CALL GET_METADATA - DO METADATA.UPDATE + if stamp_archive_element_ids: + metadata.update(stamp_archive_element_ids) + for archive_element_name in sorted(archive_element_names): + metadata.add("ARCHIVE_ELEMENT", archive_element_name) stamp.stamp_im.getImage().writeFits(filename, metadata=metadata, mode="a") if write_mask: metadata = PropertyList() - metadata.update({"EXTVER": i + 1, "EXTNAME": "MASK"}) + metadata.update({"EXTVER": extVer, "EXTNAME": "MASK"}) stamp.stamp_im.getMask().writeFits(filename, metadata=metadata, mode="a") if write_variance: metadata = PropertyList() - metadata.update({"EXTVER": i + 1, "EXTNAME": "VARIANCE"}) + metadata.update({"EXTVER": extVer, "EXTNAME": "VARIANCE"}) stamp.stamp_im.getVariance().writeFits(filename, metadata=metadata, mode="a") return None -def readFitsWithOptions(filename, stamp_factory, options): +def readFitsWithOptions( + filename: str, + stamp_factory: classmethod, + options: PropertyList, +): """Read stamps from FITS file, allowing for only a subregion of the stamps to be read. @@ -124,15 +151,20 @@ def readFitsWithOptions(filename, stamp_factory, options): Notes ----- The data are read using the data type expected by the - `~lsst.afw.image.MaskedImage` class attached to the `AbstractStamp` + `~lsst.afw.image.MaskedImage` class attached to the `StampBase` dataclass associated with the factory method. """ - # extract necessary info from metadata + # Extract necessary info from metadata metadata = readMetadata(filename, hdu=0) nStamps = metadata["N_STAMPS"] has_archive = metadata["HAS_ARCHIVE"] + archive_element_names = None + archive_element_ids_v1 = None if has_archive: - archive_ids = metadata.getArray("ARCHIVE_IDS") + if metadata["VERSION"] < 2: + archive_element_ids_v1 = metadata.getArray("ARCHIVE_ELEMENT_IDS") + else: + archive_element_names = metadata.getArray("ARCHIVE_ELEMENT") with Fits(filename, "r") as f: nExtensions = f.countHdus() # check if a bbox was provided @@ -151,9 +183,9 @@ def readFitsWithOptions(filename, stamp_factory, options): kwargs["bbox"] = bbox stamp_parts = {} - # Determine the dtype from the factory. This allows a Stamp class - # to be defined in terms of MaskedImageD or MaskedImageI without - # forcing everything to floats. + # Determine the dtype from the factory. + # This allows a Stamp class to be defined in terms of MaskedImageD or + # MaskedImageI without forcing everything to floats. masked_image_cls = None for stamp_field in fields(stamp_factory.__self__): if issubclass(stamp_field.type, MaskedImage): @@ -162,9 +194,10 @@ def readFitsWithOptions(filename, stamp_factory, options): else: raise RuntimeError("Stamp factory does not use MaskedImage.") default_dtype = np.dtype(masked_image_cls.dtype) - variance_dtype = np.dtype(np.float32) # Variance is always the same type. + variance_dtype = np.dtype(np.float32) # Variance is always the same type - # We need to be careful because nExtensions includes the primary HDU. + # We need to be careful because nExtensions includes the primary HDU + archive_ids = {} for idx in range(nExtensions - 1): dtype = None md = readMetadata(filename, hdu=idx + 1) @@ -174,6 +207,13 @@ def readFitsWithOptions(filename, stamp_factory, options): dtype = variance_dtype else: dtype = default_dtype + if archive_element_names is not None: + archive_ids[idx] = { + name: md[name] for name in archive_element_names if name in md.keys() + } + # EG: md.pop(name) + # READ METADATA AND PUT IN ANOTHER MAPPING, A LA ARCHIVE_IDS + # MAKE SURE EXTNAME DOESN'T GET ADDED HERE elif md["EXTNAME"] == "MASK": reader = MaskFitsReader(filename, hdu=idx + 1) elif md["EXTNAME"] == "ARCHIVE_INDEX": @@ -184,26 +224,39 @@ def readFitsWithOptions(filename, stamp_factory, options): continue else: raise ValueError(f"Unknown extension type: {md['EXTNAME']}") - stamp_parts.setdefault(md["EXTVER"], {})[md["EXTNAME"].lower()] = reader.read(dtype=dtype, - **kwargs) + stamp_parts.setdefault(md["EXTVER"], {})[md["EXTNAME"].lower()] = reader.read( + dtype=dtype, **kwargs + ) + if len(stamp_parts) != nStamps: raise ValueError( f"Number of stamps read ({len(stamp_parts)}) does not agree with the " f"number of stamps recorded in the metadata ({nStamps})." ) - # construct stamps themselves + # Construct the stamps themselves stamps = [] for k in range(nStamps): # Need to increment by one since EXTVER starts at 1 maskedImage = masked_image_cls(**stamp_parts[k + 1]) - archive_element = archive.get(archive_ids[k]) if has_archive else None - stamps.append(stamp_factory(maskedImage, metadata, k, archive_element)) + if archive_element_ids_v1 is not None: + archive_elements = {DEFAULT_ARCHIVE_ELEMENT_NAME: archive.get(archive_element_ids_v1[k])} + elif archive_element_names is not None: + stamp_archive_ids = archive_ids.get(k, {}) + archive_elements = {name: archive.get(id) for name, id in stamp_archive_ids.items()} + # EXTRACT METADATA FROM THE MAPPING + stamps.append(stamp_factory(maskedImage, metadata, k, archive_elements)) return stamps, metadata +def _default_position(): + # SpherePoint is nominally mutable in C++ so we must use a factory + # and return an entirely new SpherePoint each time a Stamps is created. + return SpherePoint(Angle(np.nan), Angle(np.nan)) + + @dataclass -class AbstractStamp(abc.ABC): +class StampBase(abc.ABC): """Single abstract stamp. Parameters @@ -213,40 +266,55 @@ class AbstractStamp(abc.ABC): @classmethod @abc.abstractmethod - def factory(cls, stamp_im, metadata, index, archive_element=None): - """This method is needed to service the FITS reader. We need a standard - interface to construct objects like this. Parameters needed to - construct this object are passed in via a metadata dictionary and then - passed to the constructor of this class. + def factory( + cls, + stamp_im: MaskedImage, + metadata: PropertyList, + index: int, + archive_elements: Mapping[str, Persistable] | None = None, + ): + """This method is needed to service the FITS reader. + We need a standard interface to construct objects like this. + Parameters needed to construct this object are passed in via a metadata + dictionary and then passed to the constructor of this class. Parameters ---------- - stamp : `~lsst.afw.image.MaskedImage` + stamp_im : `~lsst.afw.image.MaskedImage` Pixel data to pass to the constructor metadata : `dict` Dictionary containing the information needed by the constructor. idx : `int` Index into the lists in ``metadata`` - archive_element : `~lsst.afw.table.io.Persistable`, optional - Archive element (e.g. Transform or WCS) associated with this stamp. + archive_elements : `~collections.abc.Mapping`[ `str` , \ + `~lsst.afw.table.io.Persistable`], optional + Archive elements (e.g. Transform / WCS) associated with this stamp. Returns ------- - stamp : `AbstractStamp` + stamp : `StampBase` An instance of this class """ - raise NotImplementedError + raise NotImplementedError() + @abc.abstractmethod + def _getMaskedImage(self): + """Return the image data.""" + raise NotImplementedError() -def _default_position(): - # SpherePoint is nominally mutable in C++ so we must use a factory - # and return an entirely new SpherePoint each time a Stamps is created. - return SpherePoint(Angle(np.nan), Angle(np.nan)) + @abc.abstractmethod + def _getArchiveElements(self): + """Return the archive elements. + + Keys should be upper case names that will be used directly as FITS + header keys. + """ + raise NotImplementedError() @dataclass -class Stamp(AbstractStamp): +class Stamp(StampBase): """Single stamp. Parameters @@ -265,7 +333,7 @@ class Stamp(AbstractStamp): position: SpherePoint | None = field(default_factory=_default_position) @classmethod - def factory(cls, stamp_im, metadata, index, archive_element=None): + def factory(cls, stamp_im: MaskedImage, metadata: PropertyList, index: int, archive_elements=None): """This method is needed to service the FITS reader. We need a standard interface to construct objects like this. Parameters needed to construct this object are passed in via a metadata dictionary and then @@ -283,14 +351,23 @@ def factory(cls, stamp_im, metadata, index, archive_element=None): needed by the constructor. idx : `int` Index into the lists in ``metadata`` - archive_element : `~lsst.afw.table.io.Persistable`, optional - Archive element (e.g. Transform or WCS) associated with this stamp. + archive_elements : `~collections.abc.Mapping`[ `str` , \ + `~lsst.afw.table.io.Persistable`], optional + Archive elements (e.g. Transform / WCS) associated with this stamp. Returns ------- stamp : `Stamp` An instance of this class """ + if archive_elements: + try: + (archive_element,) = archive_elements.values() + except TypeError: + raise RuntimeError("Expected exactly one archive element.") + else: + archive_element = None + if "RA_DEG" in metadata and "DEC_DEG" in metadata: return cls( stamp_im=stamp_im, @@ -307,6 +384,12 @@ def factory(cls, stamp_im, metadata, index, archive_element=None): position=SpherePoint(Angle(np.nan), Angle(np.nan)), ) + def _getMaskedImage(self): + return self.stamp_im + + def _getArchiveElements(self): + return {DEFAULT_ARCHIVE_ELEMENT_NAME: self.archive_element} + class StampsBase(abc.ABC, Sequence): """Collection of stamps and associated metadata. @@ -339,10 +422,17 @@ class StampsBase(abc.ABC, Sequence): ) """ - def __init__(self, stamps, metadata=None, use_mask=True, use_variance=True, use_archive=False): + def __init__( + self, + stamps: list, + metadata: PropertyList | None = None, + use_mask: bool = True, + use_variance: bool = True, + use_archive: bool = False, + ): for stamp in stamps: - if not isinstance(stamp, AbstractStamp): - raise ValueError(f"The entries in stamps must inherit from AbstractStamp. Got {type(stamp)}.") + if not isinstance(stamp, StampBase): + raise ValueError(f"The entries in stamps must inherit from StampBase. Got {type(stamp)}.") self._stamps = stamps self._metadata = PropertyList() if metadata is None else metadata.deepCopy() self.use_mask = use_mask @@ -350,29 +440,28 @@ def __init__(self, stamps, metadata=None, use_mask=True, use_variance=True, use_ self.use_archive = use_archive @classmethod - def readFits(cls, filename): + def readFits(cls, filename: str): """Build an instance of this class from a file. Parameters ---------- filename : `str` - Name of the file to read + Name of the file to read. """ - - return cls.readFitsWithOptions(filename, None) + return cls.readFitsWithOptions(filename=filename, options=None) @classmethod - def readFitsWithOptions(cls, filename, options): - """Build an instance of this class with options. + def readFitsWithOptions(cls, filename: str, options: PropertyList): + """Build an instance of this class from a file, with options. Parameters ---------- filename : `str` - Name of the file to read + Name of the file to read. options : `PropertyList` - Collection of metadata parameters + Collection of metadata parameters. """ - # To avoid problems since this is no longer an abstract method. + # To avoid problems since this is no longer an abstract base method. # TO-DO: Consider refactoring this method. This class check was added # to allow the butler formatter to use a generic type but still end up # giving the correct type back, ensuring that the abstract base class @@ -380,7 +469,7 @@ def readFitsWithOptions(cls, filename, options): if cls is not StampsBase: raise NotImplementedError(f"Please implement specific FITS reader for class {cls}") - # Load metadata to get class + # Load metadata to get the class metadata = readMetadata(filename, hdu=0) type_name = metadata.get("STAMPCLS") if type_name is None: @@ -400,24 +489,24 @@ def _refresh_metadata(self): """Make sure metadata is up to date, as this object can be extended.""" raise NotImplementedError - def writeFits(self, filename): - """Write this object to a file. + def writeFits(self, filename: str): + """Write this object to a FITS file. Parameters ---------- filename : `str` - Name of file to write. + Name of the FITS file to write. """ self._refresh_metadata() type_name = get_full_type_name(self) writeFits( - filename, - self._stamps, - self._metadata, - type_name, - self.use_mask, - self.use_variance, - self.use_archive, + filename=filename, + stamps=self._stamps, + metadata=self._metadata, + type_name=type_name, + write_mask=self.use_mask, + write_variance=self.use_variance, + write_archive=self.use_archive, ) def __len__(self): @@ -437,17 +526,18 @@ def getMaskedImages(self): maskedImages : `list` [`~lsst.afw.image.MaskedImageF`] """ - return [stamp.stamp_im for stamp in self._stamps] + return [stamp._getMaskedImage() for stamp in self._stamps] def getArchiveElements(self): """Retrieve archive elements associated with each stamp. Returns ------- - archiveElements : - `list` [`~lsst.afw.table.io.Persistable`] + archiveElements : `list` [`dict`[ `str`, \ + `~lsst.afw.table.io.Persistable` ]] + A list of archive elements associated with each stamp. """ - return [stamp.archive_element for stamp in self._stamps] + return [stamp._getArchiveElements() for stamp in self._stamps] @property def metadata(self): @@ -461,9 +551,9 @@ def _refresh_metadata(self): self._metadata["DEC_DEG"] = [p.getDec().asDegrees() for p in positions] def getPositions(self): - return [s.position for s in self._stamps] + return [stamp.position for stamp in self._stamps] - def append(self, item): + def append(self, item: Stamp): """Add an additional stamp. Parameters @@ -476,7 +566,7 @@ def append(self, item): self._stamps.append(item) return None - def extend(self, stamp_list): + def extend(self, stamp_list: list[Stamp]): """Extend Stamps instance by appending elements from another instance. Parameters @@ -484,13 +574,13 @@ def extend(self, stamp_list): stamps_list : `list` [`Stamp`] List of Stamp object to append. """ - for s in stamp_list: - if not isinstance(s, Stamp): + for stamp in stamp_list: + if not isinstance(stamp, Stamp): raise ValueError("Can only extend with Stamp objects") self._stamps += stamp_list @classmethod - def readFits(cls, filename): + def readFits(cls, filename: str): """Build an instance of this class from a file. Parameters @@ -506,7 +596,7 @@ def readFits(cls, filename): return cls.readFitsWithOptions(filename, None) @classmethod - def readFitsWithOptions(cls, filename, options): + def readFitsWithOptions(cls, filename: str, options: PropertyList): """Build an instance of this class with options. Parameters