From 2d8382fbaff5dcb2fcb45087cb6b392f0635f646 Mon Sep 17 00:00:00 2001 From: Lee Kelvin Date: Tue, 19 Mar 2024 12:55:10 -0700 Subject: [PATCH] Modify Stamp storage class for multi-arch elements --- python/lsst/meas/algorithms/stamps.py | 95 +++++++++++++++++++++------ 1 file changed, 75 insertions(+), 20 deletions(-) diff --git a/python/lsst/meas/algorithms/stamps.py b/python/lsst/meas/algorithms/stamps.py index 9352326bd..cc67785e3 100644 --- a/python/lsst/meas/algorithms/stamps.py +++ b/python/lsst/meas/algorithms/stamps.py @@ -36,6 +36,8 @@ from lsst.utils import doImport from lsst.utils.introspection import get_full_type_name +DEFAULT_ARCHIVE_ELEMENT_NAME = "ELEMENT" + def writeFits(filename, stamps, metadata, type_name, write_mask, write_variance, write_archive=False): """Write a single FITS file containing all stamps. @@ -65,25 +67,34 @@ def writeFits(filename, stamps, metadata, type_name, write_mask, write_variance, metadata["N_STAMPS"] = len(stamps) metadata["STAMPCLS"] = type_name # Record version number in case of future code changes - metadata["VERSION"] = 1 + metadata["VERSION"] = 2 # create primary HDU with global metadata fitsFile = Fits(filename, "w") fitsFile.createEmpty() # Store Persistables in an OutputArchive and write it if write_archive: + archive_ids = [] oa = OutputArchive() - archive_ids = [oa.put(stamp.archive_element) for stamp in stamps] - metadata["ARCHIVE_IDS"] = archive_ids + names = set() + for stamp in stamps: + archive_elements = stamp._getArchiveElements() + archive_ids.append({name: oa.put(persistable) for name, persistable in archive_elements.items()}) + names.update(archive_elements.keys()) fitsFile.writeMetadata(metadata) oa.writeFits(fitsFile) else: + archive_ids = [None] * len(stamps) fitsFile.writeMetadata(metadata) fitsFile.closeFile() # add all pixel data optionally writing mask and variance information - for i, stamp in enumerate(stamps): + for i, (stamp, stamp_archive_ids) in enumerate(zip(stamps, archive_ids)): metadata = PropertyList() # EXTVER should be 1-based, the index from enumerate is 0-based metadata.update({"EXTVER": i + 1, "EXTNAME": "IMAGE"}) + if stamp_archive_ids: + metadata.update(stamp_archive_ids) + for name in sorted(names): + metadata.add("ARCHIVE_ELEMENT", name) stamp.stamp_im.getImage().writeFits(filename, metadata=metadata, mode="a") if write_mask: metadata = PropertyList() @@ -131,8 +142,13 @@ def readFitsWithOptions(filename, stamp_factory, options): metadata = readMetadata(filename, hdu=0) nStamps = metadata["N_STAMPS"] has_archive = metadata["HAS_ARCHIVE"] + archive_names = None + archive_ids_v1 = None if has_archive: - archive_ids = metadata.getArray("ARCHIVE_IDS") + if metadata["VERSION"] < 2: + archive_ids_v1 = metadata.getArray("ARCHIVE_IDS") + else: + archive_names = metadata.getArray("ARCHIVE_ELEMENT") with Fits(filename, "r") as f: nExtensions = f.countHdus() # check if a bbox was provided @@ -165,6 +181,7 @@ def readFitsWithOptions(filename, stamp_factory, options): variance_dtype = np.dtype(np.float32) # Variance is always the same type. # We need to be careful because nExtensions includes the primary HDU. + archive_ids = {} for idx in range(nExtensions - 1): dtype = None md = readMetadata(filename, hdu=idx + 1) @@ -174,6 +191,8 @@ def readFitsWithOptions(filename, stamp_factory, options): dtype = variance_dtype else: dtype = default_dtype + if archive_names is not None: + archive_ids[idx] = {name: md[name] for name in archive_names if name in md.keys()} elif md["EXTNAME"] == "MASK": reader = MaskFitsReader(filename, hdu=idx + 1) elif md["EXTNAME"] == "ARCHIVE_INDEX": @@ -184,8 +203,9 @@ def readFitsWithOptions(filename, stamp_factory, options): continue else: raise ValueError(f"Unknown extension type: {md['EXTNAME']}") - stamp_parts.setdefault(md["EXTVER"], {})[md["EXTNAME"].lower()] = reader.read(dtype=dtype, - **kwargs) + stamp_parts.setdefault(md["EXTVER"], {})[md["EXTNAME"].lower()] = reader.read( + dtype=dtype, **kwargs + ) if len(stamp_parts) != nStamps: raise ValueError( f"Number of stamps read ({len(stamp_parts)}) does not agree with the " @@ -196,8 +216,12 @@ def readFitsWithOptions(filename, stamp_factory, options): for k in range(nStamps): # Need to increment by one since EXTVER starts at 1 maskedImage = masked_image_cls(**stamp_parts[k + 1]) - archive_element = archive.get(archive_ids[k]) if has_archive else None - stamps.append(stamp_factory(maskedImage, metadata, k, archive_element)) + if archive_ids_v1 is not None: + archive_elements = {DEFAULT_ARCHIVE_ELEMENT_NAME: archive.get(archive_ids_v1[k])} + elif archive_names is not None: + stamp_archive_ids = archive_ids.get(k, {}) + archive_elements = {name: archive.get(id) for name, id in stamp_archive_ids.items()} + stamps.append(stamp_factory(maskedImage, metadata, k, archive_elements)) return stamps, metadata @@ -213,7 +237,7 @@ class AbstractStamp(abc.ABC): @classmethod @abc.abstractmethod - def factory(cls, stamp_im, metadata, index, archive_element=None): + def factory(cls, stamp_im, metadata, index, archive_elements=None): """This method is needed to service the FITS reader. We need a standard interface to construct objects like this. Parameters needed to construct this object are passed in via a metadata dictionary and then @@ -228,15 +252,30 @@ def factory(cls, stamp_im, metadata, index, archive_element=None): needed by the constructor. idx : `int` Index into the lists in ``metadata`` - archive_element : `~lsst.afw.table.io.Persistable`, optional - Archive element (e.g. Transform or WCS) associated with this stamp. + archive_elements : `~collections.abc.Mapping`[ `str` , \ + `~lsst.afw.table.io.Persistable`], optional + Archive elements (e.g. Transform / WCS) associated with this stamp. Returns ------- stamp : `AbstractStamp` An instance of this class """ - raise NotImplementedError + raise NotImplementedError() + + @abc.abstractmethod + def _getMaskedImage(self): + """Return the image data.""" + raise NotImplementedError() + + @abc.abstractmethod + def _getArchiveElements(self): + """Return the archive elements. + + Keys should be upper case names that will be used directly as FITS + header keys. + """ + raise NotImplementedError() def _default_position(): @@ -265,7 +304,7 @@ class Stamp(AbstractStamp): position: SpherePoint | None = field(default_factory=_default_position) @classmethod - def factory(cls, stamp_im, metadata, index, archive_element=None): + def factory(cls, stamp_im, metadata, index, archive_elements=None): """This method is needed to service the FITS reader. We need a standard interface to construct objects like this. Parameters needed to construct this object are passed in via a metadata dictionary and then @@ -283,14 +322,23 @@ def factory(cls, stamp_im, metadata, index, archive_element=None): needed by the constructor. idx : `int` Index into the lists in ``metadata`` - archive_element : `~lsst.afw.table.io.Persistable`, optional - Archive element (e.g. Transform or WCS) associated with this stamp. + archive_elements : `~collections.abc.Mapping`[ `str` , \ + `~lsst.afw.table.io.Persistable`], optional + Archive elements (e.g. Transform / WCS) associated with this stamp. Returns ------- stamp : `Stamp` An instance of this class """ + if archive_elements: + try: + (archive_element,) = archive_elements.values() + except TypeError: + raise RuntimeError("Expected exactly one archive element.") + else: + archive_element = None + if "RA_DEG" in metadata and "DEC_DEG" in metadata: return cls( stamp_im=stamp_im, @@ -307,6 +355,12 @@ def factory(cls, stamp_im, metadata, index, archive_element=None): position=SpherePoint(Angle(np.nan), Angle(np.nan)), ) + def _getMaskedImage(self): + return self.stamp_im + + def _getArchiveElements(self): + return {DEFAULT_ARCHIVE_ELEMENT_NAME: self.archive_element} + class StampsBase(abc.ABC, Sequence): """Collection of stamps and associated metadata. @@ -437,17 +491,18 @@ def getMaskedImages(self): maskedImages : `list` [`~lsst.afw.image.MaskedImageF`] """ - return [stamp.stamp_im for stamp in self._stamps] + return [stamp._getMaskedImage() for stamp in self._stamps] def getArchiveElements(self): """Retrieve archive elements associated with each stamp. Returns ------- - archiveElements : - `list` [`~lsst.afw.table.io.Persistable`] + archiveElements : `list` [`dict`[ `str`, \ + `~lsst.afw.table.io.Persistable` ]] + A list of archive elements associated with each stamp. """ - return [stamp.archive_element for stamp in self._stamps] + return [stamp._getArchiveElements() for stamp in self._stamps] @property def metadata(self):