From 409c24fc4f6431e324728602aae89e97d6d80942 Mon Sep 17 00:00:00 2001 From: David Dotson Date: Thu, 29 Jun 2023 19:57:51 -0700 Subject: [PATCH 1/3] Added LRU caching to expensive `AlchemiscaleClient` methods We add LRU caching to the following methods: - `get_network` - `get_transformation` - `get_chemicalsystem` and to the underlying `ProtocolDAGResult` retrieval method for: - `get_transformation_results` - `get_transformation_failures` - `get_task_results` - `get_task_failures` --- alchemiscale/interface/client.py | 193 +++++++++++++++++++------------ 1 file changed, 118 insertions(+), 75 deletions(-) diff --git a/alchemiscale/interface/client.py b/alchemiscale/interface/client.py index 7800982c..bc746e3c 100644 --- a/alchemiscale/interface/client.py +++ b/alchemiscale/interface/client.py @@ -9,8 +9,10 @@ import json from itertools import chain from collections import Counter +from functools import lru_cache import httpx +from async_lru import alru_cache import networkx as nx from gufe import AlchemicalNetwork, Transformation, ChemicalSystem from gufe.tokenization import GufeTokenizable, JSON_HANDLER, GufeKey @@ -216,8 +218,9 @@ def get_chemicalsystem_transformations( f"/chemicalsystems/{chemicalsystem}/transformations" ) + @lru_cache(maxsize=100) def get_network( - self, network: Union[ScopedKey, str], compress: bool = True + self, network: Union[ScopedKey, str], compress: bool = True, visualize: bool = True ) -> AlchemicalNetwork: """Retrieve an AlchemicalNetwork given its ScopedKey. @@ -231,6 +234,8 @@ def get_network( on the bandwidth of your connection to the API service. Set to ``False`` to retrieve without compressing. This is a performance optimization; it has no bearing on the result of this method call. + visualize + If ``True``, show retrieval progress indicator. Returns ------- @@ -238,23 +243,28 @@ def get_network( The retrieved AlchemicalNetwork. """ - from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn, TextColumn - - with Progress(*self._rich_waiting_columns(), transient=False) as progress: - task = progress.add_task( - f"Retrieving [bold]'{network}'[/bold]...", total=None - ) + if visualize: + from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn, TextColumn + with Progress(*self._rich_waiting_columns(), transient=False) as progress: + task = progress.add_task( + f"Retrieving [bold]'{network}'[/bold]...", total=None + ) + an = json_to_gufe( + self._get_resource(f"/networks/{network}", compress=compress) + ) + progress.start_task(task) + progress.update(task, total=1, completed=1) + else: an = json_to_gufe( - self._get_resource(f"/networks/{network}", compress=compress) - ) - progress.start_task(task) - progress.update(task, total=1, completed=1) + self._get_resource(f"/networks/{network}", compress=compress) + ) return an + @lru_cache(maxsize=10000) def get_transformation( - self, transformation: Union[ScopedKey, str], compress: bool = True + self, transformation: Union[ScopedKey, str], compress: bool = True, visualize: bool = True ) -> Transformation: """Retrieve a Transformation given its ScopedKey. @@ -268,6 +278,8 @@ def get_transformation( on the bandwidth of your connection to the API service. Set to ``False`` to retrieve without compressing. This is a performance optimization; it has no bearing on the result of this method call. + visualize + If ``True``, show retrieval progress indicator. Returns ------- @@ -275,25 +287,33 @@ def get_transformation( The retrieved Transformation. """ - from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn + if visualize: + from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn - with Progress(*self._rich_waiting_columns(), transient=False) as progress: - task = progress.add_task( - f"Retrieving [bold]'{transformation}'[/bold]...", total=None - ) + with Progress(*self._rich_waiting_columns(), transient=False) as progress: + task = progress.add_task( + f"Retrieving [bold]'{transformation}'[/bold]...", total=None + ) + tf = json_to_gufe( + self._get_resource( + f"/transformations/{transformation}", compress=compress + ) + ) + progress.start_task(task) + progress.update(task, total=1, completed=1) + else: tf = json_to_gufe( - self._get_resource( - f"/transformations/{transformation}", compress=compress + self._get_resource( + f"/transformations/{transformation}", compress=compress + ) ) - ) - progress.start_task(task) - progress.update(task, total=1, completed=1) return tf + @lru_cache(maxsize=1000) def get_chemicalsystem( - self, chemicalsystem: Union[ScopedKey, str], compress: bool = True + self, chemicalsystem: Union[ScopedKey, str], compress: bool = True, visualize: bool = True ) -> ChemicalSystem: """Retrieve a ChemicalSystem given its ScopedKey. @@ -307,6 +327,8 @@ def get_chemicalsystem( on the bandwidth of your connection to the API service. Set to ``False`` to retrieve without compressing. This is a performance optimization; it has no bearing on the result of this method call. + visualize + If ``True``, show retrieval progress indicator. Returns ------- @@ -314,21 +336,28 @@ def get_chemicalsystem( The retrieved ChemicalSystem. """ - from rich.progress import Progress + if visualize: + from rich.progress import Progress - with Progress(*self._rich_waiting_columns(), transient=False) as progress: - task = progress.add_task( - f"Retrieving [bold]'{chemicalsystem}'[/bold]...", total=None - ) + with Progress(*self._rich_waiting_columns(), transient=False) as progress: + task = progress.add_task( + f"Retrieving [bold]'{chemicalsystem}'[/bold]...", total=None + ) - cs = json_to_gufe( - self._get_resource( - f"/chemicalsystems/{chemicalsystem}", compress=compress + cs = json_to_gufe( + self._get_resource( + f"/chemicalsystems/{chemicalsystem}", compress=compress + ) ) - ) - progress.start_task(task) - progress.update(task, total=1, completed=1) + progress.start_task(task) + progress.update(task, total=1, completed=1) + else: + cs = json_to_gufe( + self._get_resource( + f"/chemicalsystems/{chemicalsystem}", compress=compress + ) + ) return cs @@ -452,7 +481,6 @@ def get_task_transformation(self, task: ScopedKey) -> ScopedKey: def _visualize_status(self, status_counts, status_object): from rich import print as rprint - from rich.table import Table title = f"{status_object}" @@ -754,55 +782,61 @@ def set_tasks_priority(self, tasks: List[ScopedKey], priority: int): ### results + @alru_cache(maxsize=10000) + async def _async_get_protocoldagresult(self, pdr_key, scope, transformation, route, compress): + pdr_sk = ScopedKey( + gufe_key=GufeKey(pdr_key), **Scope.from_str(scope).dict() + ) + + pdr_json = await self._get_resource_async( + f"/transformations/{transformation}/{route}/{pdr_sk}", compress=compress + ) + + pdr = GufeTokenizable.from_dict( + json.loads(pdr_json[0], cls=JSON_HANDLER.decoder) + ) + + return pdr + def _get_protocoldagresults( self, protocoldagresultrefs: List[Dict], transformation: ScopedKey, ok: bool, compress: bool = True, + visualize: bool = True ): - from rich.progress import Progress - if ok: route = "results" else: route = "failures" - async def async_get_protocoldagresult(protocoldagresultref): - pdr_key = protocoldagresultref["obj_key"] - scope = protocoldagresultref["scope"] - - pdr_sk = ScopedKey( - gufe_key=GufeKey(pdr_key), **Scope.from_str(scope).dict() - ) - - pdr_json = await self._get_resource_async( - f"/transformations/{transformation}/{route}/{pdr_sk}", compress=compress - ) - - pdr = GufeTokenizable.from_dict( - json.loads(pdr_json[0], cls=JSON_HANDLER.decoder) - ) - - return pdr - @use_session async def async_request(self): - with Progress(*self._rich_progress_columns(), transient=False) as progress: - task = progress.add_task( - f"Retrieving [bold]ProtocolDAGResult[/bold]s", - total=len(protocoldagresultrefs), - ) - + if visualize: + from rich.progress import Progress + with Progress(*self._rich_progress_columns(), transient=False) as progress: + task = progress.add_task( + f"Retrieving [bold]ProtocolDAGResult[/bold]s", + total=len(protocoldagresultrefs), + ) + + coros = [ + self._async_get_protocoldagresult(protocoldagresultref['obj_key'], protocoldagresultref['scope'], transformation, route, compress) + for protocoldagresultref in protocoldagresultrefs + ] + pdrs = [] + for coro in asyncio.as_completed(coros): + pdr = await coro + pdrs.append(pdr) + progress.update(task, advance=1) + progress.refresh() + else: coros = [ - async_get_protocoldagresult(protocoldagresultref) + self._async_get_protocoldagresult(protocoldagresultref['obj_key'], protocoldagresultref['scope'], transformation, route, compress) for protocoldagresultref in protocoldagresultrefs ] - pdrs = [] - for coro in asyncio.as_completed(coros): - pdr = await coro - pdrs.append(pdr) - progress.update(task, advance=1) + pdrs = await asyncio.gather(*coros) return pdrs @@ -821,6 +855,7 @@ def get_transformation_results( transformation: ScopedKey, return_protocoldagresults: bool = False, compress: bool = True, + visualize: bool = True, ) -> Union[Optional[ProtocolResult], List[ProtocolDAGResult]]: """Get a `ProtocolResult` for the given `Transformation`. @@ -847,12 +882,14 @@ def get_transformation_results( on the bandwidth of your connection to the API service. Set to ``False`` to retrieve without compressing. This is a performance optimization; it has no bearing on the result of this method call. + visualize + If ``True``, show retrieval progress indicators. """ if not return_protocoldagresults: # get the transformation if we intend to return a ProtocolResult - tf: Transformation = self.get_transformation(transformation) + tf: Transformation = self.get_transformation(transformation, visualize=visualize) # get all protocoldagresultrefs for the given transformation protocoldagresultrefs = self._get_resource( @@ -860,7 +897,7 @@ def get_transformation_results( ) pdrs = self._get_protocoldagresults( - protocoldagresultrefs, transformation, ok=True, compress=compress + protocoldagresultrefs, transformation, ok=True, compress=compress, visualize=visualize ) if return_protocoldagresults: @@ -872,7 +909,7 @@ def get_transformation_results( return None def get_transformation_failures( - self, transformation: ScopedKey, compress: bool = True + self, transformation: ScopedKey, compress: bool = True, visualize: bool = True ) -> List[ProtocolDAGResult]: """Get failed `ProtocolDAGResult`\s for the given `Transformation`. @@ -886,6 +923,8 @@ def get_transformation_failures( on the bandwidth of your connection to the API service. Set to ``False`` to retrieve without compressing. This is a performance optimization; it has no bearing on the result of this method call. + visualize + If ``True``, show retrieval progress indicators. """ # get all protocoldagresultrefs for the given transformation @@ -894,13 +933,13 @@ def get_transformation_failures( ) pdrs = self._get_protocoldagresults( - protocoldagresultrefs, transformation, ok=False, compress=compress + protocoldagresultrefs, transformation, ok=False, compress=compress, visualize=visualize ) return pdrs def get_task_results( - self, task: ScopedKey, compress: bool = True + self, task: ScopedKey, compress: bool = True, visualize: bool = True ) -> List[ProtocolDAGResult]: """Get successful `ProtocolDAGResult`s for the given `Task`. @@ -914,6 +953,8 @@ def get_task_results( on the bandwidth of your connection to the API service. Set to ``False`` to retrieve without compressing. This is a performance optimization; it has no bearing on the result of this method call. + visualize + If ``True``, show retrieval progress indicators. """ # first, get the transformation; also confirms it exists @@ -925,13 +966,13 @@ def get_task_results( ) pdrs = self._get_protocoldagresults( - protocoldagresultrefs, transformation, ok=True, compress=compress + protocoldagresultrefs, transformation, ok=True, compress=compress, visualize=visualize ) return pdrs def get_task_failures( - self, task: ScopedKey, compress: bool = True + self, task: ScopedKey, compress: bool = True, visualize: bool = True ) -> List[ProtocolDAGResult]: """Get failed `ProtocolDAGResult`s for the given `Task`. @@ -945,6 +986,8 @@ def get_task_failures( on the bandwidth of your connection to the API service. Set to ``False`` to retrieve without compressing. This is a performance optimization; it has no bearing on the result of this method call. + visualize + If ``True``, show retrieval progress indicators. """ # first, get the transformation; also confirms it exists @@ -956,7 +999,7 @@ def get_task_failures( ) pdrs = self._get_protocoldagresults( - protocoldagresultrefs, transformation, ok=False, compress=compress + protocoldagresultrefs, transformation, ok=False, compress=compress, visualize=visualize ) return pdrs From d927ce0219db4feaf0e1693e24f76815d17861fd Mon Sep 17 00:00:00 2001 From: David Dotson Date: Thu, 29 Jun 2023 20:02:47 -0700 Subject: [PATCH 2/3] Black! --- alchemiscale/interface/client.py | 104 ++++++++++++++++++++++--------- 1 file changed, 76 insertions(+), 28 deletions(-) diff --git a/alchemiscale/interface/client.py b/alchemiscale/interface/client.py index bc746e3c..71e06f93 100644 --- a/alchemiscale/interface/client.py +++ b/alchemiscale/interface/client.py @@ -220,7 +220,10 @@ def get_chemicalsystem_transformations( @lru_cache(maxsize=100) def get_network( - self, network: Union[ScopedKey, str], compress: bool = True, visualize: bool = True + self, + network: Union[ScopedKey, str], + compress: bool = True, + visualize: bool = True, ) -> AlchemicalNetwork: """Retrieve an AlchemicalNetwork given its ScopedKey. @@ -244,7 +247,13 @@ def get_network( """ if visualize: - from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn, TextColumn + from rich.progress import ( + Progress, + SpinnerColumn, + TimeElapsedColumn, + TextColumn, + ) + with Progress(*self._rich_waiting_columns(), transient=False) as progress: task = progress.add_task( f"Retrieving [bold]'{network}'[/bold]...", total=None @@ -257,14 +266,17 @@ def get_network( progress.update(task, total=1, completed=1) else: an = json_to_gufe( - self._get_resource(f"/networks/{network}", compress=compress) - ) + self._get_resource(f"/networks/{network}", compress=compress) + ) return an @lru_cache(maxsize=10000) def get_transformation( - self, transformation: Union[ScopedKey, str], compress: bool = True, visualize: bool = True + self, + transformation: Union[ScopedKey, str], + compress: bool = True, + visualize: bool = True, ) -> Transformation: """Retrieve a Transformation given its ScopedKey. @@ -304,16 +316,19 @@ def get_transformation( progress.update(task, total=1, completed=1) else: tf = json_to_gufe( - self._get_resource( - f"/transformations/{transformation}", compress=compress - ) + self._get_resource( + f"/transformations/{transformation}", compress=compress ) + ) return tf @lru_cache(maxsize=1000) def get_chemicalsystem( - self, chemicalsystem: Union[ScopedKey, str], compress: bool = True, visualize: bool = True + self, + chemicalsystem: Union[ScopedKey, str], + compress: bool = True, + visualize: bool = True, ) -> ChemicalSystem: """Retrieve a ChemicalSystem given its ScopedKey. @@ -354,10 +369,10 @@ def get_chemicalsystem( progress.update(task, total=1, completed=1) else: cs = json_to_gufe( - self._get_resource( - f"/chemicalsystems/{chemicalsystem}", compress=compress - ) + self._get_resource( + f"/chemicalsystems/{chemicalsystem}", compress=compress ) + ) return cs @@ -783,10 +798,10 @@ def set_tasks_priority(self, tasks: List[ScopedKey], priority: int): ### results @alru_cache(maxsize=10000) - async def _async_get_protocoldagresult(self, pdr_key, scope, transformation, route, compress): - pdr_sk = ScopedKey( - gufe_key=GufeKey(pdr_key), **Scope.from_str(scope).dict() - ) + async def _async_get_protocoldagresult( + self, pdr_key, scope, transformation, route, compress + ): + pdr_sk = ScopedKey(gufe_key=GufeKey(pdr_key), **Scope.from_str(scope).dict()) pdr_json = await self._get_resource_async( f"/transformations/{transformation}/{route}/{pdr_sk}", compress=compress @@ -804,7 +819,7 @@ def _get_protocoldagresults( transformation: ScopedKey, ok: bool, compress: bool = True, - visualize: bool = True + visualize: bool = True, ): if ok: route = "results" @@ -815,14 +830,23 @@ def _get_protocoldagresults( async def async_request(self): if visualize: from rich.progress import Progress - with Progress(*self._rich_progress_columns(), transient=False) as progress: + + with Progress( + *self._rich_progress_columns(), transient=False + ) as progress: task = progress.add_task( f"Retrieving [bold]ProtocolDAGResult[/bold]s", total=len(protocoldagresultrefs), ) coros = [ - self._async_get_protocoldagresult(protocoldagresultref['obj_key'], protocoldagresultref['scope'], transformation, route, compress) + self._async_get_protocoldagresult( + protocoldagresultref["obj_key"], + protocoldagresultref["scope"], + transformation, + route, + compress, + ) for protocoldagresultref in protocoldagresultrefs ] pdrs = [] @@ -833,7 +857,13 @@ async def async_request(self): progress.refresh() else: coros = [ - self._async_get_protocoldagresult(protocoldagresultref['obj_key'], protocoldagresultref['scope'], transformation, route, compress) + self._async_get_protocoldagresult( + protocoldagresultref["obj_key"], + protocoldagresultref["scope"], + transformation, + route, + compress, + ) for protocoldagresultref in protocoldagresultrefs ] pdrs = await asyncio.gather(*coros) @@ -889,7 +919,9 @@ def get_transformation_results( if not return_protocoldagresults: # get the transformation if we intend to return a ProtocolResult - tf: Transformation = self.get_transformation(transformation, visualize=visualize) + tf: Transformation = self.get_transformation( + transformation, visualize=visualize + ) # get all protocoldagresultrefs for the given transformation protocoldagresultrefs = self._get_resource( @@ -897,7 +929,11 @@ def get_transformation_results( ) pdrs = self._get_protocoldagresults( - protocoldagresultrefs, transformation, ok=True, compress=compress, visualize=visualize + protocoldagresultrefs, + transformation, + ok=True, + compress=compress, + visualize=visualize, ) if return_protocoldagresults: @@ -909,7 +945,7 @@ def get_transformation_results( return None def get_transformation_failures( - self, transformation: ScopedKey, compress: bool = True, visualize: bool = True + self, transformation: ScopedKey, compress: bool = True, visualize: bool = True ) -> List[ProtocolDAGResult]: """Get failed `ProtocolDAGResult`\s for the given `Transformation`. @@ -933,13 +969,17 @@ def get_transformation_failures( ) pdrs = self._get_protocoldagresults( - protocoldagresultrefs, transformation, ok=False, compress=compress, visualize=visualize + protocoldagresultrefs, + transformation, + ok=False, + compress=compress, + visualize=visualize, ) return pdrs def get_task_results( - self, task: ScopedKey, compress: bool = True, visualize: bool = True + self, task: ScopedKey, compress: bool = True, visualize: bool = True ) -> List[ProtocolDAGResult]: """Get successful `ProtocolDAGResult`s for the given `Task`. @@ -966,13 +1006,17 @@ def get_task_results( ) pdrs = self._get_protocoldagresults( - protocoldagresultrefs, transformation, ok=True, compress=compress, visualize=visualize + protocoldagresultrefs, + transformation, + ok=True, + compress=compress, + visualize=visualize, ) return pdrs def get_task_failures( - self, task: ScopedKey, compress: bool = True, visualize: bool = True + self, task: ScopedKey, compress: bool = True, visualize: bool = True ) -> List[ProtocolDAGResult]: """Get failed `ProtocolDAGResult`s for the given `Task`. @@ -999,7 +1043,11 @@ def get_task_failures( ) pdrs = self._get_protocoldagresults( - protocoldagresultrefs, transformation, ok=False, compress=compress, visualize=visualize + protocoldagresultrefs, + transformation, + ok=False, + compress=compress, + visualize=visualize, ) return pdrs From 1e9094f6f382ef835682b1eccfb7a6acf0d8ade1 Mon Sep 17 00:00:00 2001 From: David Dotson Date: Thu, 29 Jun 2023 20:09:53 -0700 Subject: [PATCH 3/3] Added async_lru to deps where needed --- devtools/conda-envs/alchemiscale-client.yml | 1 + devtools/conda-envs/docs.yml | 1 + devtools/conda-envs/test.yml | 1 + 3 files changed, 3 insertions(+) diff --git a/devtools/conda-envs/alchemiscale-client.yml b/devtools/conda-envs/alchemiscale-client.yml index 0f62d811..0409780d 100644 --- a/devtools/conda-envs/alchemiscale-client.yml +++ b/devtools/conda-envs/alchemiscale-client.yml @@ -27,5 +27,6 @@ dependencies: - pip: - nest_asyncio + - async_lru - git+https://github.com/openforcefield/alchemiscale.git@v0.1.2 - git+https://github.com/choderalab/perses.git@protocol-neqcyc diff --git a/devtools/conda-envs/docs.yml b/devtools/conda-envs/docs.yml index 0e7263ab..8edbc899 100644 --- a/devtools/conda-envs/docs.yml +++ b/devtools/conda-envs/docs.yml @@ -63,6 +63,7 @@ dependencies: - sphinx_rtd_theme - pip: + - async_lru - git+https://github.com/dotsdl/grolt@relax-cryptography # neo4j test server deployment - git+https://github.com/OpenFreeEnergy/gufe - git+https://github.com/OpenFreeEnergy/openfe diff --git a/devtools/conda-envs/test.yml b/devtools/conda-envs/test.yml index aac7b47c..892968c3 100644 --- a/devtools/conda-envs/test.yml +++ b/devtools/conda-envs/test.yml @@ -59,6 +59,7 @@ dependencies: - openmmforcefields - pip: + - async_lru - git+https://github.com/dotsdl/grolt@relax-cryptography # neo4j test server deployment - git+https://github.com/OpenFreeEnergy/gufe - git+https://github.com/OpenFreeEnergy/openfe