Implement bulk prefetching of datafile metadata, using batched GET requests (#498)

* Implement bulk prefetching of datafile metadata, using batched GET requests (see the sketch after this list)

* Simplify the use of validator dataclasses for representing GET request results

* Ensure the keys used in the endpoint cache are hashable

* Simplify the CacheKey class

* Inline a class which isn't needed

* Add initial tests for Overseer caching. These still need some work

* Extract fixture to a common fixture location

* Improve the overseer caching tests

* Reinstate older dataclass for storing a GET response

* Fix incorrect endpoint when prefetching datafiles

* Ensure we use a valid key for prefetching datafiles (MyTardis requires a URI/ID instead of identifier string)

* Enable optional (on-by-default) request caching in the MyTardis client

* Restrict use of caching to GET requests

* Set the requests_cache logging level to INFO by default

* Fix logging of already-ingested datafile

* Add more logging in the Overseer

* Make the Overseer cache debug logging less verbose

* Log the correct cache keys

* Ensure we serialize directory the same way in both Datafile and IngestedDatafile

* Don't cache response data with no objects

* Remove caching of smaller response data in the Overseer (instead rely on requests-cache in the MyTardis client). Now we only use the Overseer caching for the prefetch data.

* Add a log call

* Avoid caching responses from datafile GET requests, as we don't use them and they add a lot of data to the cache

* Don't log cache exclusions by default

* Merge objects into the Overseer cache if there's a key collision, instead of raising an error. Also fix a logging inefficiency.

* Log a warning if there's a cache key collision
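
For orientation, here is a minimal sketch of the per-dataset prefetch guard this commit adds in src/ingestion_factory/factory.py (the same pattern appears in the factory.py diff below). Overseer.prefetch, the "/dataset_file" endpoint, and the URI import path are taken from the diff; the wrapper function, its parameters, and the log message are illustrative only, not part of the commit.

    # Sketch only: mirrors the prefetch guard added in src/ingestion_factory/factory.py.
    # The overseer and datafile objects are assumed to behave as shown in the diff below.
    import logging

    from src.mytardis_client.endpoints import URI  # import path as used in the diff


    def prefetch_datafile_metadata(overseer, datafiles) -> None:
        """Issue one batched GET per dataset, rather than one GET per datafile."""
        prefetched: dict[URI, bool] = {}
        for datafile in datafiles:
            if prefetched.get(datafile.dataset):
                continue  # metadata for this dataset has already been fetched
            num_objects = overseer.prefetch(
                "/dataset_file", query_params={"dataset": datafile.dataset}
            )
            prefetched[datafile.dataset] = True
            logging.info(
                "Prefetched %d datafiles for dataset %s", num_objects, datafile.dataset
            )

The prefetched dict, keyed by dataset URI, keeps the metadata retrieval to a single batched request per dataset, regardless of how many datafiles belong to it.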
andrew-uoa authored Aug 29, 2024
1 parent c74606d commit 4bf7635
Showing 17 changed files with 626 additions and 74 deletions.
71 changes: 70 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -14,6 +14,7 @@ pytz = "^2024.1"
types-pytz = "^2024.1.0.20240417"
python-slugify = "^8.0.1"
pydantic-settings = "^2.3.4"
requests-cache = "^1.2.1"
validators = "^0.33.0"
typer = "^0.12.0"
rocrate = "^0.10.0"
2 changes: 2 additions & 0 deletions src/cli/cmd_ingest.py
@@ -87,6 +87,8 @@ def upload(
"Manifest directory is empty. Extract data into a manifest using 'extract' command."
)

logging.info("Loading metadata manifest from %s", manifest_dir)

manifest = IngestionManifest.deserialize(manifest_dir)

logging.info("Successfully loaded metadata manifest from %s", manifest_dir)
19 changes: 17 additions & 2 deletions src/ingestion_factory/factory.py
@@ -207,6 +207,8 @@ def ingest_datafiles(
result = IngestionResult()
datafiles: list[Datafile] = []

datafiles_prefetched: dict[URI, bool] = {}

for raw_datafile in raw_datafiles:
refined_datafile = self.smelter.smelt_datafile(raw_datafile)
if not refined_datafile:
@@ -217,6 +219,18 @@
if not datafile:
result.error.append(refined_datafile.display_name)
continue

if not datafiles_prefetched.get(datafile.dataset):
num_objects = self._overseer.prefetch(
"/dataset_file", query_params={"dataset": datafile.dataset}
)
datafiles_prefetched[datafile.dataset] = True
logging.info(
"Prefetched %d datafiles for dataset %s",
num_objects,
datafile.dataset,
)

# Add a replica to represent the copy transferred by the Conveyor.
datafile.replicas.append(self.conveyor.create_replica(datafile))

@@ -226,8 +240,9 @@
)
if len(matching_datafiles) > 0:
logging.info(
'Already ingested datafile "%s". Skipping datafile ingestion.',
datafile.directory,
'Already ingested datafile "%s" as %s. Skipping datafile ingestion.',
datafile.filepath,
matching_datafiles[0].resource_uri,
)
result.skipped.append((datafile.display_name, None))
continue
8 changes: 2 additions & 6 deletions src/mytardis_client/endpoint_info.py
@@ -17,7 +17,7 @@
Institution,
Instrument,
MyTardisIntrospection,
MyTardisResourceBase,
MyTardisObjectData,
ProjectParameterSet,
StorageBox,
)
@@ -26,11 +26,7 @@
class GetRequestProperties(BaseModel):
"""Definition of behaviour/structure for a GET request to a MyTardis endpoint."""

# Note: it would be useful here to store the dataclass type for the response, so that
# the response can be validated/deserialized without the requester needing
# to know the correct type. But the dataclasses are currently defined outside the
# mytardis_client module, and this module should ideally be self-contained.
response_obj_type: type[MyTardisResourceBase]
response_obj_type: type[MyTardisObjectData]


class PostRequestProperties(BaseModel):
79 changes: 58 additions & 21 deletions src/mytardis_client/mt_rest.py
@@ -7,12 +7,14 @@

import logging
from copy import deepcopy
from datetime import timedelta
from typing import Any, Callable, Dict, Generic, Literal, Optional, TypeVar
from urllib.parse import urljoin
from urllib.parse import urljoin, urlparse

import requests
from pydantic import BaseModel, ValidationError
from requests import ConnectTimeout, ReadTimeout, RequestException, Response
from requests import ConnectTimeout, ReadTimeout, RequestException, Response, Session
from requests_cache import CachedSession
from tenacity import (
before_sleep_log,
retry,
@@ -25,13 +27,17 @@
from src.mytardis_client.common_types import HttpRequestMethod
from src.mytardis_client.endpoint_info import get_endpoint_info
from src.mytardis_client.endpoints import URI, MyTardisEndpoint
from src.mytardis_client.response_data import MyTardisResource
from src.mytardis_client.response_data import MyTardisObjectData

# Defines the valid values for the MyTardis API version
MyTardisApiVersion = Literal["v1"]

logger = logging.getLogger(__name__)

# requests_cache is quite verbose in DEBUG - can crowd out other log messages
caching_logger = logging.getLogger("requests_cache")
caching_logger.setLevel(logging.INFO)


def make_api_stub(version: MyTardisApiVersion) -> str:
"""Creates a stub for the MyTardis API URL
@@ -76,23 +82,14 @@ class GetResponseMeta(BaseModel):
previous: Optional[str]


MyTardisObjectData = TypeVar("MyTardisObjectData", bound=BaseModel)
T = TypeVar("T", bound=BaseModel)


class GetResponse(BaseModel, Generic[MyTardisObjectData]):
"""A Pydantic model to handle the response from a GET request to the MyTardis API"""
class GetResponse(BaseModel, Generic[T]):
"""Data model for a response from a GET request to the MyTardis API"""

meta: GetResponseMeta
objects: list[MyTardisObjectData]


class Ingested(BaseModel, Generic[MyTardisObjectData]):
"""A Pydantic model to store the data of an ingested object, i.e. the response from a GET
request to the MyTardis API, along with the URI.
"""

obj: MyTardisObjectData
resource_uri: URI
objects: list[T]


def sanitize_params(params: dict[str, Any]) -> dict[str, Any]:
@@ -132,6 +129,29 @@ def replace_uri_with_id(value: Any) -> Any:
return updated_params


caching_logger = logging.getLogger("requests_cache")


def make_endpoint_filter(
endpoints: tuple[MyTardisEndpoint],
) -> Callable[[requests.Response], bool]:
"""Create a filter predicate to exclude responses from certain endpoints from
being cached.
"""

def retain_response(response: requests.Response) -> bool:
path = urlparse(response.url).path.rstrip("/")
for endpoint in endpoints:
if path.endswith(endpoint):
caching_logger.debug(
"Request cache filter excluded response from: %s", response.url
)
return False
return True

return retain_response


class MyTardisRESTFactory:
"""Class to interact with MyTardis by calling the REST API
@@ -156,6 +176,7 @@ def __init__(
auth: AuthConfig,
connection: ConnectionConfig,
request_timeout: int = 30,
use_cache: bool = True,
) -> None:
"""MyTardisRESTFactory initialisation using a configuration dictionary.
Expand All @@ -182,7 +203,18 @@ def __init__(
self._url_base = urljoin(self._hostname, self._api_stub)

self.user_agent = f"{self.user_agent_name}/2.0 ({self.user_agent_url})"
self._session = requests.Session()

# Don't cache datafile responses as they are voluminous and not likely to be reused.
self._session = (
CachedSession(
backend="memory",
expire_after=timedelta(hours=1),
allowable_methods=("GET",),
filter_fn=make_endpoint_filter(("/dataset_file",)),
)
if use_cache
else Session()
)

self._request_timeout = request_timeout

@@ -281,7 +313,7 @@ def get(
endpoint: MyTardisEndpoint,
query_params: Optional[dict[str, Any]] = None,
meta_params: Optional[GetRequestMetaParams] = None,
) -> tuple[list[MyTardisResource], GetResponseMeta]:
) -> tuple[list[MyTardisObjectData], GetResponseMeta]:
"""Submit a GET request to the MyTardis API and return the response as a list of objects.
Note that the response is paginated, so the function may not return all objects matching
@@ -309,7 +341,7 @@

response_meta = GetResponseMeta.model_validate(response_json["meta"])

objects: list[MyTardisResource] = []
objects: list[MyTardisObjectData] = []

response_objects = response_json.get("objects")
if response_objects is None:
@@ -336,15 +368,15 @@ def get_all(
endpoint: MyTardisEndpoint,
query_params: Optional[dict[str, Any]] = None,
batch_size: int = 500,
) -> tuple[list[MyTardisResource], int]:
) -> tuple[list[MyTardisObjectData], int]:
"""Get all objects of the given type that match 'query_params'.
Sends repeated GET requests to the MyTardis API until all objects have been retrieved.
The 'batch_size' argument can be used to control the number of objects retrieved in
each request
"""

objects: list[MyTardisResource] = []
objects: list[MyTardisObjectData] = []

while True:
request_meta = GetRequestMetaParams(limit=batch_size, offset=len(objects))
@@ -364,3 +396,8 @@ def get_all(
break

return objects, response_meta.total_count

def clear_cache(self) -> None:
"""Clear the cache of the requests session"""
if isinstance(self._session, CachedSession):
self._session.cache.clear() # type: ignore[no-untyped-call]
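
Not part of the commit: a small sketch of how the cache filter defined in mt_rest.py above behaves. The import path follows the file shown above; the hostname and the hand-built Response objects are placeholders.

    # Sketch only: demonstrates the /dataset_file cache exclusion added above.
    import requests

    from src.mytardis_client.mt_rest import make_endpoint_filter

    filter_fn = make_endpoint_filter(("/dataset_file",))

    datafile_response = requests.Response()
    datafile_response.url = "https://mytardis.example.org/api/v1/dataset_file/?dataset=123"
    assert filter_fn(datafile_response) is False  # datafile listings are never cached

    dataset_response = requests.Response()
    dataset_response.url = "https://mytardis.example.org/api/v1/dataset/123/"
    assert filter_fn(dataset_response) is True  # other GET responses remain cacheable
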
1 change: 1 addition & 0 deletions src/mytardis_client/objects.py
@@ -38,6 +38,7 @@ class MyTardisObject(str, Enum):
REPLICA = "replica"
SCHEMA = "schema"
STORAGE_BOX = "storagebox"
STORAGE_BOX_OPTION = "storageboxoption"
USER = "user"

