Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[UX]: support sub thread status attaching #4493

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions sky/clouds/service_catalog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from sky.clouds.service_catalog.constants import CATALOG_SCHEMA_VERSION
from sky.clouds.service_catalog.constants import HOSTED_CATALOG_DIR_URL
from sky.utils import resources_utils
from sky.utils import rich_utils
from sky.utils import subprocess_utils
from sky.utils import ux_utils

if typing.TYPE_CHECKING:
from sky.clouds import cloud
Expand Down Expand Up @@ -71,9 +74,12 @@ def list_accelerators(
Returns: A dictionary of canonical accelerator names mapped to a list
of instance type offerings. See usage in cli.py.
"""
results = _map_clouds_catalog(clouds, 'list_accelerators', gpus_only,
name_filter, region_filter, quantity_filter,
case_sensitive, all_regions, require_price)
with rich_utils.safe_status(
ux_utils.spinner_message('Listing accelerators')):
results = _map_clouds_catalog(clouds, 'list_accelerators', gpus_only,
name_filter, region_filter,
quantity_filter, case_sensitive,
all_regions, require_price)
if not isinstance(results, list):
results = [results]
ret: Dict[str,
Expand Down
79 changes: 70 additions & 9 deletions sky/utils/rich_utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
"""Rich status spinner utils."""
import contextlib
import threading
from typing import Union
from typing import Dict, Optional, Union

import rich.console as rich_console

console = rich_console.Console(soft_wrap=True)
_status = None
_status_nesting_level = 0
_main_message = None

_logging_lock = threading.RLock()

# Track sub thread progress statuses
_thread_statuses: Dict[int, Optional[str]] = {}
_status_lock = threading.RLock()


class _NoOpConsoleStatus:
"""An empty class for multi-threaded console.status."""
Expand All @@ -35,15 +40,17 @@ class _RevertibleStatus:
"""A wrapper for status that can revert to previous message after exit."""

def __init__(self, message: str):
if _status is not None:
self.previous_message = _status.status
if _main_message is not None:
self.previous_message = _main_message
else:
self.previous_message = None
self.message = message

def __enter__(self):
global _status_nesting_level
_status.update(self.message)
global _main_message
_main_message = self.message
refresh()
_status_nesting_level += 1
_status.__enter__()
return _status
Expand All @@ -57,10 +64,15 @@ def __exit__(self, exc_type, exc_val, exc_tb):
_status.__exit__(exc_type, exc_val, exc_tb)
_status = None
else:
_status.update(self.previous_message)
global _main_message
_main_message = self.previous_message
refresh()

def update(self, *args, **kwargs):
_status.update(*args, **kwargs)
global _main_message
_main_message = _status.status
refresh()

def stop(self):
_status.stop()
Expand All @@ -69,16 +81,65 @@ def start(self):
_status.start()


def safe_status(msg: str) -> Union['rich_console.Status', _NoOpConsoleStatus]:
class _ThreadStatus:
"""A wrapper of sub thread status"""

def __init__(self, message: str):
self.thread_id = threading.get_ident()
self.message = message
self.previous_message = _thread_statuses.get(self.thread_id)

def __enter__(self):
self.start()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
if self.previous_message is not None:
_thread_statuses[self.thread_id] = self.previous_message
else:
# No previous message, remove the thread status
if self.thread_id in _thread_statuses:
del _thread_statuses[self.thread_id]
refresh()

def update(self, new_message: str):
self.message = new_message
_thread_statuses[self.thread_id] = new_message
refresh()

def stop(self):
_thread_statuses[self.thread_id] = None
refresh()

def start(self):
_thread_statuses[self.thread_id] = self.message
refresh()


def refresh():
"""Refresh status to include all thread statuses."""
if _status is None or _main_message is None:
return
with _status_lock:
msg = _main_message
for v in _thread_statuses.values():
if v is not None:
msg = msg + f'\n └─ {v}'
_status.update(msg)


def safe_status(msg: str) -> Union['rich_console.Status', '_NoOpConsoleStatus']:
"""A wrapper for multi-threaded console.status."""
from sky import sky_logging # pylint: disable=import-outside-toplevel
global _status
if (threading.current_thread() is threading.main_thread() and
not sky_logging.is_silent()):
if sky_logging.is_silent():
return _NoOpConsoleStatus()
if threading.current_thread() is threading.main_thread():
if _status is None:
_status = console.status(msg, refresh_per_second=8)
return _RevertibleStatus(msg)
return _NoOpConsoleStatus()
else:
return _ThreadStatus(msg)


def stop_safe_status():
Expand Down
Loading