diff --git a/sky/clouds/service_catalog/__init__.py b/sky/clouds/service_catalog/__init__.py index d28b530ff06..842b08375d2 100644 --- a/sky/clouds/service_catalog/__init__.py +++ b/sky/clouds/service_catalog/__init__.py @@ -10,6 +10,9 @@ from sky.clouds.service_catalog.constants import CATALOG_SCHEMA_VERSION from sky.clouds.service_catalog.constants import HOSTED_CATALOG_DIR_URL from sky.utils import resources_utils +from sky.utils import rich_utils +from sky.utils import subprocess_utils +from sky.utils import ux_utils if typing.TYPE_CHECKING: from sky.clouds import cloud @@ -71,9 +74,12 @@ def list_accelerators( Returns: A dictionary of canonical accelerator names mapped to a list of instance type offerings. See usage in cli.py. """ - results = _map_clouds_catalog(clouds, 'list_accelerators', gpus_only, - name_filter, region_filter, quantity_filter, - case_sensitive, all_regions, require_price) + with rich_utils.safe_status( + ux_utils.spinner_message('Listing accelerators')): + results = _map_clouds_catalog(clouds, 'list_accelerators', gpus_only, + name_filter, region_filter, + quantity_filter, case_sensitive, + all_regions, require_price) if not isinstance(results, list): results = [results] ret: Dict[str, diff --git a/sky/utils/rich_utils.py b/sky/utils/rich_utils.py index 6badf621294..d724c968045 100644 --- a/sky/utils/rich_utils.py +++ b/sky/utils/rich_utils.py @@ -1,16 +1,21 @@ """Rich status spinner utils.""" import contextlib import threading -from typing import Union +from typing import Dict, Optional, Union import rich.console as rich_console console = rich_console.Console(soft_wrap=True) _status = None _status_nesting_level = 0 +_main_message = None _logging_lock = threading.RLock() +# Track sub thread progress statuses +_thread_statuses: Dict[int, Optional[str]] = {} +_status_lock = threading.RLock() + class _NoOpConsoleStatus: """An empty class for multi-threaded console.status.""" @@ -35,15 +40,17 @@ class _RevertibleStatus: """A wrapper for status that can revert to previous message after exit.""" def __init__(self, message: str): - if _status is not None: - self.previous_message = _status.status + if _main_message is not None: + self.previous_message = _main_message else: self.previous_message = None self.message = message def __enter__(self): global _status_nesting_level - _status.update(self.message) + global _main_message + _main_message = self.message + refresh() _status_nesting_level += 1 _status.__enter__() return _status @@ -57,10 +64,15 @@ def __exit__(self, exc_type, exc_val, exc_tb): _status.__exit__(exc_type, exc_val, exc_tb) _status = None else: - _status.update(self.previous_message) + global _main_message + _main_message = self.previous_message + refresh() def update(self, *args, **kwargs): _status.update(*args, **kwargs) + global _main_message + _main_message = _status.status + refresh() def stop(self): _status.stop() @@ -69,16 +81,65 @@ def start(self): _status.start() -def safe_status(msg: str) -> Union['rich_console.Status', _NoOpConsoleStatus]: +class _ThreadStatus: + """A wrapper of sub thread status""" + + def __init__(self, message: str): + self.thread_id = threading.get_ident() + self.message = message + self.previous_message = _thread_statuses.get(self.thread_id) + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.previous_message is not None: + _thread_statuses[self.thread_id] = self.previous_message + else: + # No previous message, remove the thread status + if self.thread_id in _thread_statuses: + del _thread_statuses[self.thread_id] + refresh() + + def update(self, new_message: str): + self.message = new_message + _thread_statuses[self.thread_id] = new_message + refresh() + + def stop(self): + _thread_statuses[self.thread_id] = None + refresh() + + def start(self): + _thread_statuses[self.thread_id] = self.message + refresh() + + +def refresh(): + """Refresh status to include all thread statuses.""" + if _status is None or _main_message is None: + return + with _status_lock: + msg = _main_message + for v in _thread_statuses.values(): + if v is not None: + msg = msg + f'\n └─ {v}' + _status.update(msg) + + +def safe_status(msg: str) -> Union['rich_console.Status', '_NoOpConsoleStatus']: """A wrapper for multi-threaded console.status.""" from sky import sky_logging # pylint: disable=import-outside-toplevel global _status - if (threading.current_thread() is threading.main_thread() and - not sky_logging.is_silent()): + if sky_logging.is_silent(): + return _NoOpConsoleStatus() + if threading.current_thread() is threading.main_thread(): if _status is None: _status = console.status(msg, refresh_per_second=8) return _RevertibleStatus(msg) - return _NoOpConsoleStatus() + else: + return _ThreadStatus(msg) def stop_safe_status():