From 4bf7635a11069061a29fcc59a55455e51b9a4db2 Mon Sep 17 00:00:00 2001
From: andrew-uoa <142769327+andrew-uoa@users.noreply.github.com>
Date: Fri, 30 Aug 2024 10:42:48 +1200
Subject: [PATCH] Implement bulk prefetching of datafile metadata, using
 batched GET requests (#498)

* Implement bulk prefetching of datafile metadata, using batched GET requests
* Simplify the use of validator dataclasses for representing GET request results
* Ensure the keys used in the endpoint cache are hashable
* Simplify the CacheKey class
* Inline a class which isn't needed
* Add initial tests for Overseer caching. Still needs some work
* Extract fixture to a common fixture location
* Improve the overseer caching tests
* Reinstate older dataclass for storing a GET response
* Fix incorrect endpoint when prefetching datafiles
* Ensure we use a valid key for prefetching datafiles (MyTardis requires a URI/ID instead of an identifier string)
* Enable optional (on-by-default) request caching in the MyTardis client
* Restrict use of caching to GET requests
* Set the requests_cache logging level to INFO by default
* Fix logging of already-ingested datafile
* Add more logging in the Overseer
* Make the Overseer cache debug logging less verbose
* Log the correct cache keys
* Ensure we serialize directory the same way in both Datafile and IngestedDatafile
* Don't cache response data with no objects
* Remove caching of smaller response data in the Overseer (instead rely on requests-cache in the MyTardis client). Now we only use the Overseer caching for the prefetch data.
* Add a log call
* Avoid caching responses from datafile GET requests, as we don't use them and they add a lot of data to the cache
* Don't log cache exclusions by default
* Merge objects into the Overseer cache if there's a key collision, instead of raising an error. Also fix a logging inefficiency.
* Log a warning if there's a cache key collision
---
 poetry.lock                                  |  71 ++++++++-
 pyproject.toml                               |   1 +
 src/cli/cmd_ingest.py                        |   2 +
 src/ingestion_factory/factory.py             |  19 ++-
 src/mytardis_client/endpoint_info.py         |   8 +-
 src/mytardis_client/mt_rest.py               |  79 +++++++---
 src/mytardis_client/objects.py               |   1 +
 src/mytardis_client/response_data.py         | 145 +++++++++++++++----
 src/overseers/overseer.py                    | 102 ++++++++++++-
 src/utils/container.py                       |  28 ++++
 src/utils/types/type_helpers.py              |   8 +-
 tests/conftest.py                            |   2 +
 tests/fixtures/fixtures_dataclasses.py       |  90 +++++++++++-
 tests/fixtures/fixtures_ingestion_classes.py |   2 +-
 tests/test_mytardis_client_rest_factory.py   |  30 +++-
 tests/test_overseer_cache.py                 | 100 +++++++++++++
 tests/test_overseers.py                      |  12 +-
 17 files changed, 626 insertions(+), 74 deletions(-)
 create mode 100644 src/utils/container.py
 create mode 100644 tests/test_overseer_cache.py

diff --git a/poetry.lock b/poetry.lock
index 6038c090..f45f5c47 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -386,6 +386,31 @@ dev = ["CacheControl[filecache,redis]", "black", "build", "cherrypy", "furo", "m
 filecache = ["filelock (>=3.8.0)"]
 redis = ["redis (>=2.10.5)"]
 
+[[package]]
+name = "cattrs"
+version = "23.2.3"
+description = "Composable complex class support for attrs and dataclasses."
+optional = false +python-versions = ">=3.8" +files = [ + {file = "cattrs-23.2.3-py3-none-any.whl", hash = "sha256:0341994d94971052e9ee70662542699a3162ea1e0c62f7ce1b4a57f563685108"}, + {file = "cattrs-23.2.3.tar.gz", hash = "sha256:a934090d95abaa9e911dac357e3a8699e0b4b14f8529bcc7d2b1ad9d51672b9f"}, +] + +[package.dependencies] +attrs = ">=23.1.0" +exceptiongroup = {version = ">=1.1.1", markers = "python_version < \"3.11\""} +typing-extensions = {version = ">=4.1.0,<4.6.3 || >4.6.3", markers = "python_version < \"3.11\""} + +[package.extras] +bson = ["pymongo (>=4.4.0)"] +cbor2 = ["cbor2 (>=5.4.6)"] +msgpack = ["msgpack (>=1.0.5)"] +orjson = ["orjson (>=3.9.2)"] +pyyaml = ["pyyaml (>=6.0)"] +tomlkit = ["tomlkit (>=0.11.8)"] +ujson = ["ujson (>=5.7.0)"] + [[package]] name = "certifi" version = "2024.7.4" @@ -2785,6 +2810,36 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "requests-cache" +version = "1.2.1" +description = "A persistent cache for python requests" +optional = false +python-versions = ">=3.8" +files = [ + {file = "requests_cache-1.2.1-py3-none-any.whl", hash = "sha256:1285151cddf5331067baa82598afe2d47c7495a1334bfe7a7d329b43e9fd3603"}, + {file = "requests_cache-1.2.1.tar.gz", hash = "sha256:68abc986fdc5b8d0911318fbb5f7c80eebcd4d01bfacc6685ecf8876052511d1"}, +] + +[package.dependencies] +attrs = ">=21.2" +cattrs = ">=22.2" +platformdirs = ">=2.5" +requests = ">=2.22" +url-normalize = ">=1.4" +urllib3 = ">=1.25.5" + +[package.extras] +all = ["boto3 (>=1.15)", "botocore (>=1.18)", "itsdangerous (>=2.0)", "pymongo (>=3)", "pyyaml (>=6.0.1)", "redis (>=3)", "ujson (>=5.4)"] +bson = ["bson (>=0.5)"] +docs = ["furo (>=2023.3,<2024.0)", "linkify-it-py (>=2.0,<3.0)", "myst-parser (>=1.0,<2.0)", "sphinx (>=5.0.2,<6.0.0)", "sphinx-autodoc-typehints (>=1.19)", "sphinx-automodapi (>=0.14)", "sphinx-copybutton (>=0.5)", "sphinx-design (>=0.2)", "sphinx-notfound-page (>=0.8)", "sphinxcontrib-apidoc (>=0.3)", "sphinxext-opengraph (>=0.9)"] +dynamodb = ["boto3 (>=1.15)", "botocore (>=1.18)"] +json = ["ujson (>=5.4)"] +mongodb = ["pymongo (>=3)"] +redis = ["redis (>=3)"] +security = ["itsdangerous (>=2.0)"] +yaml = ["pyyaml (>=6.0.1)"] + [[package]] name = "requests-toolbelt" version = "1.0.0" @@ -3584,6 +3639,20 @@ files = [ [package.extras] test = ["coverage", "pytest", "pytest-cov"] +[[package]] +name = "url-normalize" +version = "1.4.3" +description = "URL normalization for Python" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" +files = [ + {file = "url-normalize-1.4.3.tar.gz", hash = "sha256:d23d3a070ac52a67b83a1c59a0e68f8608d1cd538783b401bc9de2c0fac999b2"}, + {file = "url_normalize-1.4.3-py2.py3-none-any.whl", hash = "sha256:ec3c301f04e5bb676d333a7fa162fa977ad2ca04b7e652bfc9fac4e405728eed"}, +] + +[package.dependencies] +six = "*" + [[package]] name = "urllib3" version = "2.2.2" @@ -3766,4 +3835,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.12" -content-hash = "c7fc4536719485e4d31fb389508f326ed9eedd89fef536e04f7687b805a1e170" +content-hash = "a5cccbd962a128030355f4e9006ce44b9fbfaebff322a95f21cf219c36b613cb" diff --git a/pyproject.toml b/pyproject.toml index 0a39308c..ca51142d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ pytz = "^2024.1" types-pytz = "^2024.1.0.20240417" python-slugify = "^8.0.1" pydantic-settings = "^2.3.4" +requests-cache = "^1.2.1" validators = "^0.33.0" typer = "^0.12.0" 
rocrate = "^0.10.0" diff --git a/src/cli/cmd_ingest.py b/src/cli/cmd_ingest.py index 77bb42b6..3ca5da9d 100644 --- a/src/cli/cmd_ingest.py +++ b/src/cli/cmd_ingest.py @@ -87,6 +87,8 @@ def upload( "Manifest directory is empty. Extract data into a manifest using 'extract' command." ) + logging.info("Loading metadata manifest from %s", manifest_dir) + manifest = IngestionManifest.deserialize(manifest_dir) logging.info("Successfully loaded metadata manifest from %s", manifest_dir) diff --git a/src/ingestion_factory/factory.py b/src/ingestion_factory/factory.py index 6d4db1ae..4320c16e 100644 --- a/src/ingestion_factory/factory.py +++ b/src/ingestion_factory/factory.py @@ -207,6 +207,8 @@ def ingest_datafiles( result = IngestionResult() datafiles: list[Datafile] = [] + datafiles_prefetched: dict[URI, bool] = {} + for raw_datafile in raw_datafiles: refined_datafile = self.smelter.smelt_datafile(raw_datafile) if not refined_datafile: @@ -217,6 +219,18 @@ def ingest_datafiles( if not datafile: result.error.append(refined_datafile.display_name) continue + + if not datafiles_prefetched.get(datafile.dataset): + num_objects = self._overseer.prefetch( + "/dataset_file", query_params={"dataset": datafile.dataset} + ) + datafiles_prefetched[datafile.dataset] = True + logging.info( + "Prefetched %d datafiles for dataset %s", + num_objects, + datafile.dataset, + ) + # Add a replica to represent the copy transferred by the Conveyor. datafile.replicas.append(self.conveyor.create_replica(datafile)) @@ -226,8 +240,9 @@ def ingest_datafiles( ) if len(matching_datafiles) > 0: logging.info( - 'Already ingested datafile "%s". Skipping datafile ingestion.', - datafile.directory, + 'Already ingested datafile "%s" as %s. Skipping datafile ingestion.', + datafile.filepath, + matching_datafiles[0].resource_uri, ) result.skipped.append((datafile.display_name, None)) continue diff --git a/src/mytardis_client/endpoint_info.py b/src/mytardis_client/endpoint_info.py index 5c701f55..9594815e 100644 --- a/src/mytardis_client/endpoint_info.py +++ b/src/mytardis_client/endpoint_info.py @@ -17,7 +17,7 @@ Institution, Instrument, MyTardisIntrospection, - MyTardisResourceBase, + MyTardisObjectData, ProjectParameterSet, StorageBox, ) @@ -26,11 +26,7 @@ class GetRequestProperties(BaseModel): """Definition of behaviour/structure for a GET request to a MyTardis endpoint.""" - # Note: it would be useful here to store the dataclass type for the response, so that - # the response can be validated/deserialized without the requester needing - # to know the correct type. But the dataclasses are currently defined outside the - # mytardis_client module, and this module should ideally be self-contained. 
-    response_obj_type: type[MyTardisResourceBase]
+    response_obj_type: type[MyTardisObjectData]
 
 
 class PostRequestProperties(BaseModel):
diff --git a/src/mytardis_client/mt_rest.py b/src/mytardis_client/mt_rest.py
index 36e6a26d..bc0864d7 100644
--- a/src/mytardis_client/mt_rest.py
+++ b/src/mytardis_client/mt_rest.py
@@ -7,12 +7,14 @@
 
 import logging
 from copy import deepcopy
+from datetime import timedelta
 from typing import Any, Callable, Dict, Generic, Literal, Optional, TypeVar
-from urllib.parse import urljoin
+from urllib.parse import urljoin, urlparse
 
 import requests
 from pydantic import BaseModel, ValidationError
-from requests import ConnectTimeout, ReadTimeout, RequestException, Response
+from requests import ConnectTimeout, ReadTimeout, RequestException, Response, Session
+from requests_cache import CachedSession
 from tenacity import (
     before_sleep_log,
     retry,
@@ -25,13 +27,17 @@
 from src.mytardis_client.common_types import HttpRequestMethod
 from src.mytardis_client.endpoint_info import get_endpoint_info
 from src.mytardis_client.endpoints import URI, MyTardisEndpoint
-from src.mytardis_client.response_data import MyTardisResource
+from src.mytardis_client.response_data import MyTardisObjectData
 
 # Defines the valid values for the MyTardis API version
 MyTardisApiVersion = Literal["v1"]
 
 logger = logging.getLogger(__name__)
 
+# requests_cache is quite verbose in DEBUG - can crowd out other log messages
+caching_logger = logging.getLogger("requests_cache")
+caching_logger.setLevel(logging.INFO)
+
 
 def make_api_stub(version: MyTardisApiVersion) -> str:
     """Creates a stub for the MyTardis API URL
@@ -76,23 +82,14 @@ class GetResponseMeta(BaseModel):
     previous: Optional[str]
 
 
-MyTardisObjectData = TypeVar("MyTardisObjectData", bound=BaseModel)
+T = TypeVar("T", bound=BaseModel)
 
 
-class GetResponse(BaseModel, Generic[MyTardisObjectData]):
-    """A Pydantic model to handle the response from a GET request to the MyTardis API"""
+class GetResponse(BaseModel, Generic[T]):
+    """Data model for a response from a GET request to the MyTardis API"""
 
     meta: GetResponseMeta
-    objects: list[MyTardisObjectData]
-
-
-class Ingested(BaseModel, Generic[MyTardisObjectData]):
-    """A Pydantic model to store the data of an ingested object, i.e. the response from a GET
-    request to the MyTardis API, along with the URI.
-    """
-
-    obj: MyTardisObjectData
-    resource_uri: URI
+    objects: list[T]
 
 
 def sanitize_params(params: dict[str, Any]) -> dict[str, Any]:
@@ -132,6 +129,29 @@ def replace_uri_with_id(value: Any) -> Any:
     return updated_params
 
 
+def make_endpoint_filter(
+    endpoints: tuple[MyTardisEndpoint, ...],
+) -> Callable[[requests.Response], bool]:
+    """Create a filter predicate to exclude responses from certain endpoints from
+    being cached.
+    """
+
+    def retain_response(response: requests.Response) -> bool:
+        path = urlparse(response.url).path.rstrip("/")
+        for endpoint in endpoints:
+            if path.endswith(endpoint):
+                caching_logger.debug(
+                    "Request cache filter excluded response from: %s", response.url
+                )
+                return False
+        return True
+
+    return retain_response
+
+
 class MyTardisRESTFactory:
     """Class to interact with MyTardis by calling the REST API
 
@@ -156,6 +176,7 @@ def __init__(
         self,
         auth: AuthConfig,
         connection: ConnectionConfig,
         request_timeout: int = 30,
+        use_cache: bool = True,
     ) -> None:
         """MyTardisRESTFactory initialisation using a configuration dictionary.
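To make the exclusion filter's behaviour concrete, here is a minimal sketch of how the predicate responds to different URLs (the host and paths below are hypothetical):

    import requests

    from src.mytardis_client.mt_rest import make_endpoint_filter

    retain = make_endpoint_filter(("/dataset_file",))

    response = requests.Response()
    response.url = "https://mytardis.example.com/api/v1/dataset_file/?limit=500"
    assert retain(response) is False  # datafile responses stay out of the cache

    response.url = "https://mytardis.example.com/api/v1/dataset/1/"
    assert retain(response) is True  # everything else remains cacheable
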
@@ -182,7 +203,18 @@ def __init__( self._url_base = urljoin(self._hostname, self._api_stub) self.user_agent = f"{self.user_agent_name}/2.0 ({self.user_agent_url})" - self._session = requests.Session() + + # Don't cache datafile responses as they are voluminous and not likely to be reused. + self._session = ( + CachedSession( + backend="memory", + expire_after=timedelta(hours=1), + allowable_methods=("GET",), + filter_fn=make_endpoint_filter(("/dataset_file",)), + ) + if use_cache + else Session() + ) self._request_timeout = request_timeout @@ -281,7 +313,7 @@ def get( endpoint: MyTardisEndpoint, query_params: Optional[dict[str, Any]] = None, meta_params: Optional[GetRequestMetaParams] = None, - ) -> tuple[list[MyTardisResource], GetResponseMeta]: + ) -> tuple[list[MyTardisObjectData], GetResponseMeta]: """Submit a GET request to the MyTardis API and return the response as a list of objects. Note that the response is paginated, so the function may not return all objects matching @@ -309,7 +341,7 @@ def get( response_meta = GetResponseMeta.model_validate(response_json["meta"]) - objects: list[MyTardisResource] = [] + objects: list[MyTardisObjectData] = [] response_objects = response_json.get("objects") if response_objects is None: @@ -336,7 +368,7 @@ def get_all( endpoint: MyTardisEndpoint, query_params: Optional[dict[str, Any]] = None, batch_size: int = 500, - ) -> tuple[list[MyTardisResource], int]: + ) -> tuple[list[MyTardisObjectData], int]: """Get all objects of the given type that match 'query_params'. Sends repeated GET requests to the MyTardis API until all objects have been retrieved. @@ -344,7 +376,7 @@ def get_all( each request """ - objects: list[MyTardisResource] = [] + objects: list[MyTardisObjectData] = [] while True: request_meta = GetRequestMetaParams(limit=batch_size, offset=len(objects)) @@ -364,3 +396,8 @@ def get_all( break return objects, response_meta.total_count + + def clear_cache(self) -> None: + """Clear the cache of the requests session""" + if isinstance(self._session, CachedSession): + self._session.cache.clear() # type: ignore[no-untyped-call] diff --git a/src/mytardis_client/objects.py b/src/mytardis_client/objects.py index ad43e700..fe206db1 100644 --- a/src/mytardis_client/objects.py +++ b/src/mytardis_client/objects.py @@ -38,6 +38,7 @@ class MyTardisObject(str, Enum): REPLICA = "replica" SCHEMA = "schema" STORAGE_BOX = "storagebox" + STORAGE_BOX_OPTION = "storageboxoption" USER = "user" diff --git a/src/mytardis_client/response_data.py b/src/mytardis_client/response_data.py index 5be13416..ca906137 100644 --- a/src/mytardis_client/response_data.py +++ b/src/mytardis_client/response_data.py @@ -1,9 +1,10 @@ """Dataclasses for validating/storing MyTardis API response data.""" +from abc import abstractmethod from pathlib import Path -from typing import Any, Optional, Protocol +from typing import Any, Optional -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field, field_serializer, model_validator from typing_extensions import Self from src.mytardis_client.common_types import ( @@ -16,54 +17,68 @@ from src.mytardis_client.objects import MyTardisObject -# pylint: disable=too-few-public-methods -class MyTardisResource(Protocol): - """Protocol for MyTardis resources.""" +class MyTardisObjectData(BaseModel): + """Base class for object data retrieved from MyTardis. Defines minimal fields, e.g. 
URI""" - id: int - resource_uri: URI - - -class MyTardisResourceBase(BaseModel): - """Base class for data retrieved from MyTardis, associated with an ingested object.""" + @property + @abstractmethod + def mytardis_type(self) -> MyTardisObject: + """The type of the MyTardis object.""" + raise NotImplementedError("mytardis_type must be implemented") id: int resource_uri: URI -class Group(MyTardisResourceBase): +class Group(MyTardisObjectData): """Metadata associated with a group in MyTardis.""" + @property + def mytardis_type(self) -> MyTardisObject: + return MyTardisObject.GROUP + name: str -class Facility(MyTardisResourceBase): +class Facility(MyTardisObjectData): """Metadata associated with a facility in MyTardis.""" + @property + def mytardis_type(self) -> MyTardisObject: + return MyTardisObject.FACILITY + created_time: ISODateTime manager_group: Group modified_time: ISODateTime name: str -class Institution(MyTardisResourceBase): +class Institution(MyTardisObjectData): """Metadata associated with an institution in MyTardis.""" + @property + def mytardis_type(self) -> MyTardisObject: + return MyTardisObject.INSTITUTION + aliases: Optional[list[str]] identifiers: list[str] name: str -class Instrument(MyTardisResourceBase): +class Instrument(MyTardisObjectData): """Metadata associated with an instrument in MyTardis.""" + @property + def mytardis_type(self) -> MyTardisObject: + return MyTardisObject.INSTRUMENT + created_time: ISODateTime facility: Facility modified_time: ISODateTime name: str -class MyTardisIntrospection(MyTardisResourceBase): +class MyTardisIntrospection(MyTardisObjectData): """MyTardis introspection data (the configuration of the MyTardis instance). NOTE: this class relies on data from the MyTardis introspection API and @@ -73,6 +88,10 @@ class MyTardisIntrospection(MyTardisResourceBase): model_config = ConfigDict(use_enum_values=False) + @property + def mytardis_type(self) -> MyTardisObject: + return MyTardisObject.INTROSPECTION + data_classification_enabled: Optional[bool] identifiers_enabled: bool objects_with_ids: list[MyTardisObject] = Field( @@ -119,9 +138,13 @@ def validate_consistency(self) -> Self: return self -class ParameterName(MyTardisResourceBase): +class ParameterName(MyTardisObjectData): """Schema parameter information""" + @property + def mytardis_type(self) -> MyTardisObject: + return MyTardisObject.PARAMETER_NAME + full_name: str immutable: bool is_searchable: bool @@ -131,9 +154,13 @@ class ParameterName(MyTardisResourceBase): units: str -class Replica(MyTardisResourceBase): +class Replica(MyTardisObjectData): """Metadata associated with a Datafile replica in MyTardis.""" + @property + def mytardis_type(self) -> MyTardisObject: + return MyTardisObject.REPLICA + created_time: ISODateTime datafile: URI last_verified_time: Optional[ISODateTime] @@ -142,9 +169,13 @@ class Replica(MyTardisResourceBase): verified: bool -class Schema(MyTardisResourceBase): +class Schema(MyTardisObjectData): """Metadata associated with a metadata schema in MyTardis.""" + @property + def mytardis_type(self) -> MyTardisObject: + return MyTardisObject.SCHEMA + hidden: bool immutable: bool name: str @@ -152,18 +183,26 @@ class Schema(MyTardisResourceBase): parameter_names: list[ParameterName] -class StorageBoxOption(MyTardisResourceBase): +class StorageBoxOption(MyTardisObjectData): """Data associated with a storage box option in MyTardis.""" + @property + def mytardis_type(self) -> MyTardisObject: + return MyTardisObject.STORAGE_BOX_OPTION + key: str storage_box: URI value: str 
 value_type: str
 
 
-class StorageBox(MyTardisResourceBase):
+class StorageBox(MyTardisObjectData):
     """Metadata associated with a storage box in MyTardis."""
 
+    @property
+    def mytardis_type(self) -> MyTardisObject:
+        return MyTardisObject.STORAGE_BOX
+
     attributes: list[str]
     description: str
     django_storage_class: str
@@ -173,9 +212,13 @@
     status: str
 
 
-class User(MyTardisResourceBase):
+class User(MyTardisObjectData):
     """Data associated with a user in MyTardis."""
 
+    @property
+    def mytardis_type(self) -> MyTardisObject:
+        return MyTardisObject.USER
+
     email: Optional[str]
     first_name: Optional[str]
     groups: list[Group]
@@ -183,25 +226,45 @@
     username: str
 
 
-class ProjectParameterSet(MyTardisResourceBase):
+class ProjectParameterSet(MyTardisObjectData):
     """Metadata associated with a project parameter set in MyTardis."""
 
+    @property
+    def mytardis_type(self) -> MyTardisObject:
+        return MyTardisObject.PROJECT_PARAMETER_SET
+
 
-class ExperimentParameterSet(MyTardisResourceBase):
+class ExperimentParameterSet(MyTardisObjectData):
     """Metadata associated with an experiment parameter set in MyTardis."""
 
+    @property
+    def mytardis_type(self) -> MyTardisObject:
+        return MyTardisObject.EXPERIMENT_PARAMETER_SET
+
 
-class DatasetParameterSet(MyTardisResourceBase):
+class DatasetParameterSet(MyTardisObjectData):
     """Metadata associated with a dataset parameter set in MyTardis."""
 
+    @property
+    def mytardis_type(self) -> MyTardisObject:
+        return MyTardisObject.DATASET_PARAMETER_SET
+
 
-class DatafileParameterSet(MyTardisResourceBase):
+class DatafileParameterSet(MyTardisObjectData):
     """Metadata associated with a datafile parameter set in MyTardis."""
 
+    @property
+    def mytardis_type(self) -> MyTardisObject:
+        return MyTardisObject.DATAFILE_PARAMETER_SET
+
 
-class IngestedProject(MyTardisResourceBase):
+class IngestedProject(MyTardisObjectData):
     """Metadata associated with a project that has been ingested into MyTardis."""
 
+    @property
+    def mytardis_type(self) -> MyTardisObject:
+        return MyTardisObject.PROJECT
+
     classification: DataClassification
     description: str
     identifiers: Optional[list[str]]
@@ -211,9 +274,13 @@
     principal_investigator: str
 
 
-class IngestedExperiment(MyTardisResourceBase):
+class IngestedExperiment(MyTardisObjectData):
     """Metadata associated with an experiment that has been ingested into MyTardis."""
 
+    @property
+    def mytardis_type(self) -> MyTardisObject:
+        return MyTardisObject.EXPERIMENT
+
     classification: int
     description: str
     identifiers: Optional[list[str]]
@@ -222,9 +289,13 @@
     title: str
 
 
-class IngestedDataset(MyTardisResourceBase):
+class IngestedDataset(MyTardisObjectData):
     """Metadata associated with a dataset that has been ingested into MyTardis."""
 
+    @property
+    def mytardis_type(self) -> MyTardisObject:
+        return MyTardisObject.DATASET
+
     classification: DataClassification
     created_time: ISODateTime
     description: str
@@ -238,9 +309,13 @@
     public_access: bool
 
 
-class IngestedDatafile(MyTardisResourceBase):
+class IngestedDatafile(MyTardisObjectData):
     """Metadata associated with a datafile that has been ingested into MyTardis."""
 
+    @property
+    def mytardis_type(self) -> MyTardisObject:
+        return MyTardisObject.DATAFILE
+
     created_time: Optional[ISODateTime]
     dataset: URI
     deleted: bool
@@ -256,3 +331,13 @@
     replicas: list[Replica]
     size: int
     version: int
+
+    @field_serializer("directory")
+    def dir_as_posix_path(self, directory: Optional[Path]) -> Optional[str]:
+        """Ensures the directory is always serialized as a posix path, or `None` if not set.
+
+        Note: this is mainly for parity with `Datafile`, as otherwise we fail to match
+        corresponding pre-ingest/ingested datafiles, because one stores directory as
+        a string, and the other as a Path object.
+        """
+        return directory.as_posix() if directory else None
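The serializer parity issue is easiest to see in isolation. In the sketch below, FileRecord is a hypothetical stand-in for IngestedDatafile's directory handling, not part of this patch:

    from pathlib import Path
    from typing import Optional

    from pydantic import BaseModel, field_serializer


    class FileRecord(BaseModel):
        """Hypothetical stand-in showing the directory serializer."""

        directory: Optional[Path]

        @field_serializer("directory")
        def dir_as_posix_path(self, directory: Optional[Path]) -> Optional[str]:
            return directory.as_posix() if directory else None


    # Without the serializer, model_dump() would yield PosixPath("path/to/df_1"),
    # which never compares equal to the plain string a Datafile dump produces.
    assert FileRecord(directory=Path("path/to/df_1")).model_dump() == {"directory": "path/to/df_1"}
    assert FileRecord(directory=None).model_dump() == {"directory": None}
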
diff --git a/src/overseers/overseer.py b/src/overseers/overseer.py
index 6784949a..9eb67f73 100644
--- a/src/overseers/overseer.py
+++ b/src/overseers/overseer.py
@@ -8,10 +8,13 @@
 
 from typeguard import check_type
 
+from src.mytardis_client.endpoint_info import get_endpoint_info
 from src.mytardis_client.endpoints import URI, MyTardisEndpoint
 from src.mytardis_client.mt_rest import MyTardisRESTFactory
 from src.mytardis_client.objects import MyTardisObject, get_type_info
-from src.mytardis_client.response_data import MyTardisIntrospection, MyTardisResource
+from src.mytardis_client.response_data import MyTardisIntrospection, MyTardisObjectData
+from src.utils.container import lazyjoin
+from src.utils.types.type_helpers import is_list_of
 
 logger = logging.getLogger(__name__)
 
@@ -64,6 +67,53 @@
     return match_keys
 
 
+class MyTardisEndpointCache:
+    """A cache for URIs and objects from a specific MyTardis endpoint"""
+
+    def __init__(self, endpoint: MyTardisEndpoint) -> None:
+        self.endpoint = endpoint
+        self._objects: list[list[MyTardisObjectData]] = []
+        self._index: dict[tuple[tuple[str, Any], ...], int] = {}
+
+    def _to_hashable(self, keys: dict[str, Any]) -> tuple[tuple[str, Any], ...]:
+        return tuple(keys.items())
+
+    def emplace(self, keys: dict[str, Any], objects: list[MyTardisObjectData]) -> None:
+        """Add objects to the cache"""
+        hashable_keys = self._to_hashable(keys)
+
+        # Compare against None explicitly: a stored index of 0 is falsy.
+        if (index := self._index.get(hashable_keys)) is not None:
+            logger.warning(
+                "Cache entry already exists for keys: %s. Merging values.",
+                hashable_keys,
+            )
+            self._objects[index].extend(objects)
+        else:
+            self._index[hashable_keys] = len(self._objects)
+            self._objects.append(objects)
+
+        logger.debug(
+            "Cache entry added. key: %s, objects: %s",
+            hashable_keys,
+            lazyjoin(", ", (obj.resource_uri for obj in objects)),
+        )
+
+    def get(self, keys: dict[str, Any]) -> list[MyTardisObjectData] | None:
+        """Get objects from the cache"""
+        hashable_keys = self._to_hashable(keys)
+        object_index = self._index.get(hashable_keys)
+        if object_index is not None:
+            logger.debug(
+                "Cache hit. keys: %s, objects: %s",
+                hashable_keys,
+                [obj.resource_uri for obj in self._objects[object_index]],
+            )
+            return self._objects[object_index]
+
+        logger.debug("Cache miss. keys: %s", hashable_keys)
+        return None
+
+
 class Overseer:
     """The Overseer class inspects MyTardis
 
@@ -100,6 +150,8 @@
         """
         self.rest_factory = rest_factory
 
+        self._cache: dict[MyTardisEndpoint, MyTardisEndpointCache] = {}
+
     @property
     def mytardis_setup(self) -> MyTardisIntrospection:
         """Getter for mytardis_setup. Sends API request on first call and caches the result"""
@@ -147,11 +199,15 @@ def _get_matches_from_mytardis(
         self,
         object_type: MyTardisObject,
         query_params: dict[str, str],
-    ) -> list[MyTardisResource]:
+    ) -> list[MyTardisObjectData]:
         """Get objects from MyTardis that match the given query parameters"""
 
         endpoint = get_default_endpoint(object_type)
 
+        if (endpoint_cache := self._cache.get(endpoint)) is not None:
+            # None signals a cache miss; an empty list is still a valid hit.
+            if (objects := endpoint_cache.get(query_params)) is not None:
+                return objects
+
         try:
             objects, _ = self.rest_factory.get(endpoint, query_params)
         except Exception as error:
@@ -162,11 +218,51 @@
 
         return objects
 
+    def prefetch(
+        self,
+        endpoint: MyTardisEndpoint,
+        query_params: dict[str, Any],
+    ) -> int:
+        """Populate the cache for the given endpoint.
+
+        Returns the number of objects prefetched.
+        """
+
+        endpoint_info = get_endpoint_info(endpoint)
+        if endpoint_info.methods.GET is None:
+            raise ValueError(f"Endpoint {endpoint} does not support GET requests")
+
+        logger.info("Prefetching from %s with query params %s", endpoint, query_params)
+
+        if self._cache.get(endpoint) is None:
+            self._cache[endpoint] = MyTardisEndpointCache(endpoint)
+
+        objects, _ = self.rest_factory.get_all(endpoint, query_params)
+
+        # Needed to ensure model_dump() is available. Can we avoid this somehow?
+        if not is_list_of(objects, endpoint_info.methods.GET.response_obj_type):
+            raise ValueError(
+                f"Expected GET request to yield list of "
+                f"{endpoint_info.methods.GET.response_obj_type}, "
+                f"but got {objects}"
+            )
+
+        for obj in objects:
+            mt_type = obj.mytardis_type
+            matchers = self.generate_object_matchers(mt_type, obj.model_dump())
+
+            for keys in matchers:
+                self._cache[endpoint].emplace(keys, [obj])
+
+        logger.info("Prefetched %d objects from %s", len(objects), endpoint)
+
+        return len(objects)
+
     def get_matching_objects(
         self,
         object_type: MyTardisObject,
         object_data: dict[str, str],
-    ) -> list[MyTardisResource]:
+    ) -> list[MyTardisObjectData]:
         """Retrieve objects from MyTardis with field values matching the ones in "field_values"
 
         The function extracts the type-dependent match keys from 'object_data' and uses them to
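Taken together, the intended calling pattern is roughly the following sketch (the auth/connection objects and the match-field values are assumed for illustration):

    from src.mytardis_client.endpoints import URI
    from src.mytardis_client.mt_rest import MyTardisRESTFactory
    from src.mytardis_client.objects import MyTardisObject
    from src.overseers.overseer import Overseer

    # 'auth' and 'connection' are assumed to be existing AuthConfig/ConnectionConfig values.
    overseer = Overseer(MyTardisRESTFactory(auth, connection))

    # One batched round of GET requests warms the endpoint cache for a whole dataset...
    num_objects = overseer.prefetch(
        "/dataset_file", query_params={"dataset": URI("/api/v1/dataset/1/")}
    )

    # ...so the per-file duplicate checks during ingestion become cache lookups.
    matches = overseer.get_matching_objects(
        MyTardisObject.DATAFILE,
        {"filename": "df_1.txt", "directory": "path/to/df_1", "dataset": "/api/v1/dataset/1/"},
    )
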
diff --git a/src/utils/container.py b/src/utils/container.py
new file mode 100644
index 00000000..d109edda
--- /dev/null
+++ b/src/utils/container.py
@@ -0,0 +1,28 @@
+"""Utility functions for working with containers."""
+
+from typing import Any, Iterable
+
+from src.utils.types.type_helpers import Stringable
+
+
+def subdict(full_dict: dict[str, Any], keys: list[str]) -> dict[str, Any]:
+    """Return a sub-dictionary of the full dictionary, containing only the entries
+    whose keys are listed in 'keys'.
+    """
+    return {key: full_dict[key] for key in keys}
+
+
+# pylint: disable=invalid-name
+class lazyjoin:
+    """Class used to lazily join strings together.
+
+    For example, to log a list of items only if the log level is set to DEBUG:
+        logger.debug(lazyjoin(", ", items))
+    """
+
+    def __init__(self, s: str, items: Iterable[Stringable]):
+        self.s = s
+        self.items = items
+
+    def __str__(self) -> str:
+        return self.s.join(str(item) for item in self.items)
diff --git a/src/utils/types/type_helpers.py b/src/utils/types/type_helpers.py
index c1f8574a..ae9aee51 100644
--- a/src/utils/types/type_helpers.py
+++ b/src/utils/types/type_helpers.py
@@ -1,6 +1,6 @@
 """Helpers for working with types and type-checking."""
 
-from typing import Any, TypeGuard, TypeVar
+from typing import Any, Protocol, TypeGuard, TypeVar
 
 T = TypeVar("T")
 
@@ -14,3 +14,9 @@
 def is_list_of(obj: Any, query_type: type[T]) -> TypeGuard[list[T]]:
     """Check if an object is a list with elements of a certain type."""
     return isinstance(obj, list) and all(isinstance(entry, query_type) for entry in obj)
+
+
+class Stringable(Protocol):
+    """Protocol for objects that can be converted to a string."""
+
+    def __str__(self) -> str: ...
diff --git a/tests/conftest.py b/tests/conftest.py
index c89fb7a3..c43b50ab 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -158,6 +158,8 @@ def datadir(tmpdir: str, request: FixtureRequest) -> Path:
 experiment = dcls.experiment
 dataset = dcls.dataset
 datafile = dcls.datafile
+make_datafile = dcls.make_datafile
+make_ingested_datafile = dcls.make_ingested_datafile
 
 # ========================================= #
diff --git a/tests/fixtures/fixtures_dataclasses.py b/tests/fixtures/fixtures_dataclasses.py
index 445c2180..510993f8 100644
--- a/tests/fixtures/fixtures_dataclasses.py
+++ b/tests/fixtures/fixtures_dataclasses.py
@@ -3,8 +3,9 @@
 
 from datetime import datetime
 from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Protocol, TypeVar
 
+from pydantic import BaseModel
 from pytest import fixture
 
 from src.blueprints.common_models import GroupACL, Parameter, ParameterSet, UserACL
@@ -19,6 +20,7 @@
 from src.blueprints.project import Project, RawProject, RefinedProject
 from src.mytardis_client.common_types import DataClassification
 from src.mytardis_client.endpoints import URI
+from src.mytardis_client.response_data import IngestedDatafile
 
 
 @fixture
@@ -439,3 +441,89 @@
         datafile_replica,
     ],
 )
+
+
+T_co = TypeVar("T_co", bound=BaseModel, covariant=True)
+
+
+class TestModelFactory(Protocol[T_co]):
+    """Protocol for a factory function that creates pydantic models to be used in tests.
+
+    Used in place of Callable, as it is difficult to declare a Callable that takes **kwargs.
+    """
+
+    def __call__(self, **kwargs: Any) -> T_co: ...
+
+
+_DEFAULT_DATACLASS_ARGS: dict[type, dict[str, Any]] = {
+    Datafile: {
+        "filename": "test_file.txt",
+        "directory": Path("path/to/datafile"),
+        "md5sum": "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF",
+        "mimetype": "text/plain",
+        "size": 1024,
+        "users": None,
+        "groups": None,
+        "data_status": None,
+        "replicas": [],
+        "parameter_sets": None,
+        "dataset": URI("/api/v1/dataset/1/"),
+    },
+    IngestedDatafile: {
+        "resource_uri": URI("/api/v1/dataset_file/1/"),
+        "id": 1,
+        "dataset": URI("/api/v1/dataset/1/"),
+        "deleted": False,
+        "directory": Path("path/to/df_1"),
+        "filename": "df_1.txt",
+        "md5sum": "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF",
+        "mimetype": "text/plain",
+        "parameter_sets": [],
+        "public_access": False,
+        "replicas": [],
+        "size": 1024,
+        "version": 1,
+        "created_time": None,
+        "deleted_time": None,
+        "modification_time": None,
+        "identifiers": ["dataset-id-1"],
+    },
+}
+
+
+def make_dataclass_factory(dc_type: type[T_co]) -> TestModelFactory[T_co]:
+    """Factory function for creating factories for specific dataclasses.
+
+    Returns a function that creates instances of the specified dataclass with default
+    argument values. These values can be overridden by passing keyword arguments to the
+    factory function. This allows testers to easily create instances of dataclasses,
+    while only specifying the values that are relevant to the test.
+
+    Args:
+        dc_type: The dataclass type for which to create a factory function.
+
+    Returns:
+        A factory function that creates instances of the specified dataclass 'dc_type'.
+    """
+
+    default_args = _DEFAULT_DATACLASS_ARGS[dc_type]
+
+    def _make_dataclass(**kwargs: Any) -> T_co:
+        """Create an instance of the dataclass with the specified keyword arguments
+        and default values (kwargs are given priority over default values).
+        """
+
+        result = {**default_args, **kwargs}
+        return dc_type.model_validate(result)
+
+    return _make_dataclass
+
+
+@fixture
+def make_datafile() -> TestModelFactory[Datafile]:
+    return make_dataclass_factory(Datafile)
+
+
+@fixture
+def make_ingested_datafile() -> TestModelFactory[IngestedDatafile]:
+    return make_dataclass_factory(IngestedDatafile)
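In a test, the factory fixture can then be used along these lines (the overridden field values here are illustrative):

    def test_overridden_datafile(
        make_ingested_datafile: TestModelFactory[IngestedDatafile],
    ) -> None:
        df = make_ingested_datafile(filename="override.txt", size=2048)
        assert df.filename == "override.txt"  # explicitly overridden
        assert df.mimetype == "text/plain"    # default retained
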
diff --git a/tests/fixtures/fixtures_ingestion_classes.py b/tests/fixtures/fixtures_ingestion_classes.py
index f33d6492..c194c16a 100644
--- a/tests/fixtures/fixtures_ingestion_classes.py
+++ b/tests/fixtures/fixtures_ingestion_classes.py
@@ -35,7 +35,7 @@ def rest_factory(
     auth: AuthConfig,
     connection: ConnectionConfig,
 ) -> MyTardisRESTFactory:
-    return MyTardisRESTFactory(auth, connection)
+    return MyTardisRESTFactory(auth, connection, use_cache=False)
 
 
 @fixture
diff --git a/tests/test_mytardis_client_rest_factory.py b/tests/test_mytardis_client_rest_factory.py
index 18435dd8..8d081929 100644
--- a/tests/test_mytardis_client_rest_factory.py
+++ b/tests/test_mytardis_client_rest_factory.py
@@ -17,11 +17,12 @@
 
 from src.blueprints.datafile import Datafile
 from src.config.config import AuthConfig, ConnectionConfig
-from src.mytardis_client.endpoints import URI
+from src.mytardis_client.endpoints import URI, MyTardisEndpoint
 from src.mytardis_client.mt_rest import (
     GetRequestMetaParams,
     GetResponseMeta,
     MyTardisRESTFactory,
+    make_endpoint_filter,
     sanitize_params,
 )
 from src.mytardis_client.response_data import IngestedDatafile
@@ -306,3 +307,30 @@
         query_params={"dataset": URI("/api/v1/dataset/0/")},
         meta_params=GetRequestMetaParams(limit=1, offset=0),
     )
+
+
+@pytest.mark.parametrize(
+    "endpoints,url,expected_output",
+    [
+        pytest.param(("/dataset_file",), "https://example.com/dataset_file", False),
+        pytest.param(("/dataset",), "https://example.com/dataset_file/", True),
+        pytest.param(("/dataset",), "https://example.com/dataset/", False),
+        pytest.param(("/dataset",), "https://example.com/dataset", False),
+        pytest.param(
+            (
+                "/dataset",
+                "/dataset_file",
+            ),
+            "https://example.com/dataset",
+            False,
+        ),
+    ],
+)
+def test_make_endpoint_filter(
+    endpoints: tuple[MyTardisEndpoint, ...], url: str, expected_output: bool
+) -> None:
+
+    response = MagicMock()
+    response.url = url
+
+    assert make_endpoint_filter(endpoints)(response) == expected_output
diff --git a/tests/test_overseer_cache.py b/tests/test_overseer_cache.py
new file mode 100644
index 00000000..131f14d8
--- /dev/null
+++ b/tests/test_overseer_cache.py
@@ -0,0 +1,100 @@
+"""
+Tests for caching in the Overseer.
+"""
+
+# pylint: disable=missing-function-docstring, missing-class-docstring
+
+from typing import Any
+
+import responses
+from responses import matchers
+
+from src.config.config import AuthConfig, ConnectionConfig
+from src.mytardis_client.endpoints import URI
+from src.mytardis_client.mt_rest import (
+    GetResponse,
+    GetResponseMeta,
+    MyTardisRESTFactory,
+)
+from src.mytardis_client.objects import MyTardisObject, get_type_info
+from src.mytardis_client.response_data import IngestedDatafile, MyTardisObjectData
+from src.overseers.overseer import MyTardisEndpointCache, Overseer
+from src.utils.container import subdict
+from tests.fixtures.fixtures_dataclasses import TestModelFactory
+
+
+def test_overseer_endpoint_cache(
+    make_ingested_datafile: TestModelFactory[IngestedDatafile],
+) -> None:
+
+    df_cache = MyTardisEndpointCache("/dataset_file")
+
+    df_1 = make_ingested_datafile()
+
+    objects: list[MyTardisObjectData] = [df_1]
+
+    object_dict = df_1.model_dump()
+    keys = subdict(object_dict, ["filename", "directory", "dataset"])
+
+    assert df_cache.get(keys) is None
+
+    df_cache.emplace(keys, objects)
+
+    assert df_cache.get(keys) == objects
+
+
+@responses.activate
+def test_overseer_prefetch(
+    auth: AuthConfig,
+    connection: ConnectionConfig,
+    make_ingested_datafile: TestModelFactory[IngestedDatafile],
+    introspection_response: dict[str, Any],
+) -> None:
+
+    mt_client = MyTardisRESTFactory(auth, connection)
+    overseer = Overseer(mt_client)
+
+    total_count = 100
+
+    ingested_datafiles = [
+        make_ingested_datafile(
+            filename=f"ingested-file-{i}.txt", dataset=URI("/api/v1/dataset/1/")
+        )
+        for i in range(total_count)
+    ]
+
+    responses.add(
+        responses.GET,
+        mt_client.compose_url("/dataset_file"),
+        match=[
+            matchers.query_param_matcher(
+                {"dataset": "dataset-id-1", "limit": "500", "offset": "0"}
+            ),
+        ],
+        status=200,
+        json=GetResponse(
+            objects=ingested_datafiles,
+            meta=GetResponseMeta(
+                limit=500, offset=0, total_count=total_count, next=None, previous=None
+            ),
+        ).model_dump(mode="json"),
+    )
+
+    responses.add(
+        responses.GET,
+        mt_client.compose_url("/introspection"),
+        status=200,
+        json=introspection_response,
+    )
+
+    num_objects = overseer.prefetch("/dataset_file", {"dataset": "dataset-id-1"})
+
+    assert num_objects == total_count
+
+    match_fields = get_type_info(MyTardisObject.DATAFILE).match_fields
+
+    for df in ingested_datafiles:
+        match_keys = subdict(df.model_dump(), match_fields)
+        matches = overseer.get_matching_objects(MyTardisObject.DATAFILE, match_keys)
+        assert len(matches) == 1
+        assert matches[0] == df
diff --git a/tests/test_overseers.py b/tests/test_overseers.py
index 9589fba3..757ed295 100644
--- a/tests/test_overseers.py
+++ b/tests/test_overseers.py
@@ -67,12 +67,12 @@ def test_get_matches_from_mytardis(
         status=200,
     )
 
-    # pylint: disable=protected-access
     expected_projects = [
         IngestedProject.model_validate(proj)
         for proj in project_response_dict["objects"]
     ]
 
+    # pylint: disable=protected-access
     retrieved_projects = overseer._get_matches_from_mytardis(
         object_type,
         {"name": project_name},
@@ -82,14 +82,12 @@
 
     assert retrieved_projects == expected_projects
 
     # pylint: disable=protected-access
-    assert (
-        overseer._get_matches_from_mytardis(
-            object_type,
-            {"identifier": project_identifiers[0]},
-        )
-        == expected_projects
+    identifier_matches = overseer._get_matches_from_mytardis(
+        object_type, {"identifier": project_identifiers[0]}
     )
 
+    assert identifier_matches == expected_projects
+
 
 @responses.activate
 def 
test_get_objects_http_error(