From efcae39714d96075798a72fede286d7ec3d3acf4 Mon Sep 17 00:00:00 2001 From: Chad Dombrova Date: Sun, 19 May 2024 10:35:57 -0700 Subject: [PATCH] Add type annotations --- pyproject.toml | 17 + src/rez/__init__.py | 3 - src/rez/build_process.py | 70 ++- src/rez/build_system.py | 73 ++- src/rez/cli/_complete_util.py | 3 +- src/rez/cli/_main.py | 4 +- src/rez/cli/_util.py | 5 +- src/rez/cli/benchmark.py | 8 +- src/rez/cli/build.py | 10 +- src/rez/cli/complete.py | 6 +- src/rez/cli/interpret.py | 1 - src/rez/config.py | 31 +- src/rez/deprecations.py | 2 +- src/rez/developer_package.py | 8 +- src/rez/package_bind.py | 4 +- src/rez/package_copy.py | 21 +- src/rez/package_filter.py | 32 +- src/rez/package_maker.py | 15 +- src/rez/package_order.py | 128 +++-- src/rez/package_repository.py | 114 +++-- src/rez/package_resources.py | 82 ++- src/rez/package_test.py | 8 +- src/rez/packages.py | 118 +++-- src/rez/pip.py | 2 +- src/rez/plugin_managers.py | 68 ++- src/rez/release_vcs.py | 22 +- src/rez/resolved_context.py | 154 +++--- src/rez/resolver.py | 48 +- src/rez/rex.py | 24 +- src/rez/shells.py | 67 +-- src/rez/solver.py | 475 ++++++++++-------- src/rez/status.py | 8 +- src/rez/suite.py | 84 +++- src/rez/system.py | 4 +- src/rez/utils/__init__.py | 5 +- src/rez/utils/backcompat.py | 2 +- src/rez/utils/data_utils.py | 100 ++-- src/rez/utils/filesystem.py | 1 + src/rez/utils/patching.py | 4 +- src/rez/utils/platform_.py | 7 +- src/rez/utils/resources.py | 52 +- src/rez/utils/schema.py | 3 +- src/rez/utils/sourcecode.py | 7 +- src/rez/utils/typing.py | 20 + src/rez/version/_requirement.py | 95 ++-- src/rez/version/_version.py | 333 ++++++------ src/rezplugins/build_process/local.py | 36 +- src/rezplugins/build_system/cmake.py | 26 +- src/rezplugins/build_system/custom.py | 26 +- .../package_repository/filesystem.py | 107 ++-- src/rezplugins/package_repository/memory.py | 36 +- src/rezplugins/release_hook/amqp.py | 5 +- 52 files changed, 1569 insertions(+), 1015 deletions(-) create mode 100644 src/rez/utils/typing.py diff --git a/pyproject.toml b/pyproject.toml index fed528d4a..2b5a32a12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,20 @@ [build-system] requires = ["setuptools"] build-backend = "setuptools.build_meta" + +[tool.mypy] +files = ["src/rez/", "src/rezplugins/"] +exclude = [ + '.*/rez/data/.*', + '.*/rez/vendor/.*', + '.*/rez/tests/.*', + '.*/rez/utils/lint_helper.py', +] +disable_error_code = ["var-annotated", "import-not-found"] +check_untyped_defs = true +# allow this for now: +allow_redefinition = true + +[[tool.mypy.overrides]] +module = 'rez.utils.lint_helper' +follow_imports = "skip" diff --git a/src/rez/__init__.py b/src/rez/__init__.py index be0a91c2e..c83f4183a 100644 --- a/src/rez/__init__.py +++ b/src/rez/__init__.py @@ -54,10 +54,7 @@ def callback(sig, frame): txt = ''.join(traceback.format_stack(frame)) print() print(txt) - else: - callback = None - if callback: signal.signal(signal.SIGUSR1, callback) # Register handler diff --git a/src/rez/build_process.py b/src/rez/build_process.py index a7d3e280d..65a0ffa81 100644 --- a/src/rez/build_process.py +++ b/src/rez/build_process.py @@ -2,6 +2,8 @@ # Copyright Contributors to the Rez Project +from __future__ import annotations + from rez.packages import iter_packages from rez.exceptions import BuildProcessError, BuildContextResolveError, \ ReleaseHookCancellingError, RezError, ReleaseError, BuildError, \ @@ -18,6 +20,13 @@ import getpass import os.path import sys +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from rez.build_system import BuildSystem + from rez.packages import Package, Variant + from rez.release_vcs import ReleaseVCS + from rez.developer_package import DeveloperPackage debug_print = config.debug_printer("package_release") @@ -29,9 +38,16 @@ def get_build_process_types(): return plugin_manager.get_plugins('build_process') -def create_build_process(process_type, working_dir, build_system, package=None, - vcs=None, ensure_latest=True, skip_repo_errors=False, - ignore_existing_tag=False, verbose=False, quiet=False): +def create_build_process(process_type: str, + working_dir: str, + build_system: BuildSystem, + package=None, + vcs: ReleaseVCS | None = None, + ensure_latest: bool = True, + skip_repo_errors: bool = False, + ignore_existing_tag: bool = False, + verbose: bool = False, + quiet: bool = False) -> BuildProcess: """Create a :class:`BuildProcess` instance. .. warning:: @@ -44,7 +60,7 @@ def create_build_process(process_type, working_dir, build_system, package=None, if process_type not in process_types: raise BuildProcessError("Unknown build process: %r" % process_type) - cls = plugin_manager.get_plugin_class('build_process', process_type) + cls = plugin_manager.get_plugin_class('build_process', process_type, BuildProcess) return cls(working_dir, # ignored (deprecated) build_system, @@ -77,7 +93,8 @@ class BuildProcess(object): def name(cls): raise NotImplementedError - def __init__(self, working_dir, build_system, package=None, vcs=None, + def __init__(self, working_dir: str, build_system: BuildSystem, package=None, + vcs: ReleaseVCS | None = None, ensure_latest=True, skip_repo_errors=False, ignore_existing_tag=False, verbose=False, quiet=False): """Create a BuildProcess. @@ -119,14 +136,15 @@ def __init__(self, working_dir, build_system, package=None, vcs=None, self.package.config.build_directory) @property - def package(self): + def package(self) -> DeveloperPackage: return self.build_system.package @property - def working_dir(self): + def working_dir(self) -> str: return self.build_system.working_dir - def build(self, install_path=None, clean=False, install=False, variants=None): + def build(self, install_path: str | None = None, clean: bool = False, + install: bool = False, variants: list[int] | None = None) -> int: """Perform the build process. Iterates over the package's variants, resolves the environment for @@ -149,7 +167,8 @@ def build(self, install_path=None, clean=False, install=False, variants=None): """ raise NotImplementedError - def release(self, release_message=None, variants=None): + def release(self, release_message: str | None = None, + variants: list[int] | None = None) -> int: """Perform the release process. Iterates over the package's variants, building and installing each into @@ -167,7 +186,7 @@ def release(self, release_message=None, variants=None): """ raise NotImplementedError - def get_changelog(self): + def get_changelog(self) -> str | None: """Get the changelog since last package release. Returns: @@ -187,7 +206,8 @@ def repo_operation(self): except exc_type as e: print_warning("THE FOLLOWING ERROR WAS SKIPPED:\n%s" % str(e)) - def visit_variants(self, func, variants=None, **kwargs): + def visit_variants(self, func, variants: list[int] | None = None, + **kwargs) -> tuple[int, list[str | None]]: """Iterate over variants and call a function on each.""" if variants: present_variants = range(self.package.num_variants) @@ -215,7 +235,7 @@ def visit_variants(self, func, variants=None, **kwargs): return num_visited, results - def get_package_install_path(self, path): + def get_package_install_path(self, path: str) -> str: """Return the installation path for a package (where its payload goes). Args: @@ -230,7 +250,8 @@ def get_package_install_path(self, path): package_version=self.package.version ) - def create_build_context(self, variant, build_type, build_path): + def create_build_context(self, variant: Variant, build_type: BuildType, + build_path: str) -> tuple[ResolvedContext, str]: """Create a context to build the variant within.""" request = variant.get_requires(build_requires=True, private_build_requires=True) @@ -274,7 +295,7 @@ def create_build_context(self, variant, build_type, build_path): raise BuildContextResolveError(context) return context, rxt_filepath - def pre_release(self): + def pre_release(self) -> None: release_settings = self.package.config.plugins.release_vcs # test that the release path exists @@ -322,7 +343,7 @@ def pre_release(self): else: break - def post_release(self, release_message=None): + def post_release(self, release_message=None) -> None: tag_name = self.get_current_tag_name() if self.vcs is None: @@ -332,7 +353,7 @@ def post_release(self, release_message=None): with self.repo_operation(): self.vcs.create_release_tag(tag_name=tag_name, message=release_message) - def get_current_tag_name(self): + def get_current_tag_name(self) -> str: release_settings = self.package.config.plugins.release_vcs try: tag_name = self.package.format(release_settings.tag_name) @@ -342,7 +363,7 @@ def get_current_tag_name(self): tag_name = "unversioned" return tag_name - def run_hooks(self, hook_event, **kwargs): + def run_hooks(self, hook_event, **kwargs) -> None: hook_names = self.package.config.release_hooks or [] hooks = create_release_hooks(hook_names, self.working_dir) @@ -357,12 +378,12 @@ def run_hooks(self, hook_event, **kwargs): "%s cancelled by %s hook '%s': %s:\n%s" % (hook_event.noun, hook_event.label, hook.name(), e.__class__.__name__, str(e))) - except RezError: + except RezError as e: debug_print("Error in %s hook '%s': %s:\n%s" % (hook_event.label, hook.name(), e.__class__.__name__, str(e))) - def get_previous_release(self): + def get_previous_release(self) -> Package | None: release_path = self.package.config.release_packages_path it = iter_packages(self.package.name, paths=[release_path]) packages = sorted(it, key=lambda x: x.version, reverse=True) @@ -372,7 +393,7 @@ def get_previous_release(self): return package return None - def get_changelog(self): + def get_changelog(self) -> str | None: previous_package = self.get_previous_release() if previous_package: previous_revision = previous_package.revision @@ -380,10 +401,11 @@ def get_changelog(self): previous_revision = None changelog = None - with self.repo_operation(): - changelog = self.vcs.get_changelog( - previous_revision, - max_revisions=config.max_package_changelog_revisions) + if self.vcs: + with self.repo_operation(): + changelog = self.vcs.get_changelog( + previous_revision, + max_revisions=config.max_package_changelog_revisions) return changelog diff --git a/src/rez/build_system.py b/src/rez/build_system.py index 9f27b8e0b..0f741c794 100644 --- a/src/rez/build_system.py +++ b/src/rez/build_system.py @@ -2,13 +2,34 @@ # Copyright Contributors to the Rez Project +from __future__ import annotations + +import argparse import os.path +from typing import TYPE_CHECKING + from rez.build_process import BuildType from rez.exceptions import BuildSystemError from rez.packages import get_developer_package from rez.rex_bindings import VariantBinding +if TYPE_CHECKING: + from typing import TypedDict # not available until python 3.8 + from rez.developer_package import DeveloperPackage + from rez.resolved_context import ResolvedContext + from rez.packages import Package, Variant + from rez.rex import RexExecutor + + # FIXME: move this out of TYPE_CHECKING block when python 3.7 support is dropped + class BuildResult(TypedDict, total=False): + success: bool + extra_files: list[str] + build_env_script: str + +else: + BuildResult = dict + def get_buildsys_types(): """Returns the available build system implementations - cmake, make etc.""" @@ -16,7 +37,8 @@ def get_buildsys_types(): return plugin_manager.get_plugins('build_system') -def get_valid_build_systems(working_dir, package=None): +def get_valid_build_systems(working_dir: str, + package: Package | None = None) -> list[type[BuildSystem]]: """Returns the build system classes that could build the source in given dir. Args: @@ -41,19 +63,19 @@ def get_valid_build_systems(working_dir, package=None): if package: if getattr(package, "build_command", None) is not None: - buildsys_name = "custom" + buildsys_name: str | None = "custom" else: buildsys_name = getattr(package, "build_system", None) # package explicitly specifies build system if buildsys_name: - cls = plugin_manager.get_plugin_class('build_system', buildsys_name) + cls = plugin_manager.get_plugin_class('build_system', buildsys_name, BuildSystem) return [cls] # detect valid build systems clss = [] - for buildsys_name in get_buildsys_types(): - cls = plugin_manager.get_plugin_class('build_system', buildsys_name) + for buildsys_name_ in get_buildsys_types(): + cls = plugin_manager.get_plugin_class('build_system', buildsys_name_, BuildSystem) if cls.is_valid_root(working_dir, package=package): clss.append(cls) @@ -67,9 +89,10 @@ def get_valid_build_systems(working_dir, package=None): return clss -def create_build_system(working_dir, buildsys_type=None, package=None, opts=None, +def create_build_system(working_dir: str, buildsys_type: str | None = None, + package=None, opts=None, write_build_scripts=False, verbose=False, - build_args=[], child_build_args=[]): + build_args=[], child_build_args=[]) -> BuildSystem: """Return a new build system that can build the source in working_dir.""" from rez.plugin_managers import plugin_manager @@ -89,7 +112,7 @@ def create_build_system(working_dir, buildsys_type=None, package=None, opts=None buildsys_type = next(iter(clss)).name() # create instance of build system - cls_ = plugin_manager.get_plugin_class('build_system', buildsys_type) + cls_ = plugin_manager.get_plugin_class('build_system', buildsys_type, BuildSystem) return cls_(working_dir, opts=opts, @@ -104,12 +127,13 @@ class BuildSystem(object): """A build system, such as cmake, make, Scons etc. """ @classmethod - def name(cls): + def name(cls) -> str: """Return the name of the build system, eg 'make'.""" raise NotImplementedError - def __init__(self, working_dir, opts=None, package=None, - write_build_scripts=False, verbose=False, build_args=[], + def __init__(self, working_dir: str, opts=None, + package: DeveloperPackage | None = None, + write_build_scripts: bool = False, verbose: bool = False, build_args=[], child_build_args=[]): """Create a build system instance. @@ -143,12 +167,12 @@ def __init__(self, working_dir, opts=None, package=None, self.opts = opts @classmethod - def is_valid_root(cls, path): + def is_valid_root(cls, path: str, package=None) -> bool: """Return True if this build system can build the source in path.""" raise NotImplementedError @classmethod - def child_build_system(cls): + def child_build_system(cls) -> str | None: """Returns the child build system. Some build systems, such as cmake, don't build the source directly. @@ -163,19 +187,24 @@ def child_build_system(cls): return None @classmethod - def bind_cli(cls, parser, group): + def bind_cli(cls, parser: argparse.ArgumentParser, group: argparse._ArgumentGroup): """Expose parameters to an argparse.ArgumentParser that are specific to this build system. Args: parser (`ArgumentParser`): Arg parser. - group (`ArgumentGroup`): Arg parser group - you should add args to + group (`_ArgumentGroup`): Arg parser group - you should add args to this, NOT to `parser`. """ pass - def build(self, context, variant, build_path, install_path, install=False, - build_type=BuildType.local): + def build(self, + context: ResolvedContext, + variant: Variant, + build_path: str, + install_path: str, + install: bool = False, + build_type=BuildType.local) -> BuildResult: """Implement this method to perform the actual build. Args: @@ -206,8 +235,9 @@ def build(self, context, variant, build_path, install_path, install=False, raise NotImplementedError @classmethod - def set_standard_vars(cls, executor, context, variant, build_type, install, - build_path, install_path=None): + def set_standard_vars(cls, executor: RexExecutor, context: ResolvedContext, + variant: Variant, build_type: BuildType, install: bool, build_path: str, + install_path: str | None = None) -> None: """Set some standard env vars that all build systems can rely on. """ from rez.config import config @@ -294,8 +324,9 @@ def add_pre_build_commands(cls, executor, variant, build_type, install, executor.execute_code(pre_build_commands) @classmethod - def add_standard_build_actions(cls, executor, context, variant, build_type, - install, build_path, install_path=None): + def add_standard_build_actions(cls, executor: RexExecutor, context: ResolvedContext, variant: Variant, + build_type: BuildType, install: bool, build_path: str, + install_path: str | None = None) -> None: """Perform build actions common to every build system. """ diff --git a/src/rez/cli/_complete_util.py b/src/rez/cli/_complete_util.py index 3b34ca95a..9902d509b 100644 --- a/src/rez/cli/_complete_util.py +++ b/src/rez/cli/_complete_util.py @@ -93,9 +93,10 @@ def __call__(self, prefix, **kwargs): return [] matching_names = [] - names = (x for x in names if x.startswith(fileprefix)) for name in names: + if not name.startswith(fileprefix): + continue filepath = os.path.join(path, name) if os.path.isfile(filepath): if not self.files: diff --git a/src/rez/cli/_main.py b/src/rez/cli/_main.py index a5a9744f0..478dc7146 100644 --- a/src/rez/cli/_main.py +++ b/src/rez/cli/_main.py @@ -5,6 +5,8 @@ """ The main command-line entry point. """ +from __future__ import annotations + import sys import importlib from argparse import _StoreTrueAction, SUPPRESS @@ -165,7 +167,7 @@ def run(command=None): extra_arg_groups = [] if opts.debug or _env_var_true("REZ_DEBUG"): - exc_type = _NeverError + exc_type: type[RezError] = _NeverError else: exc_type = RezError diff --git a/src/rez/cli/_util.py b/src/rez/cli/_util.py index 917c27777..bf9caa3e1 100644 --- a/src/rez/cli/_util.py +++ b/src/rez/cli/_util.py @@ -2,11 +2,14 @@ # Copyright Contributors to the Rez Project +from __future__ import annotations + import os import sys import signal from argparse import _SubParsersAction, ArgumentParser, SUPPRESS, \ ArgumentError +from typing import Any # Subcommands and their behaviors. @@ -18,7 +21,7 @@ # The '--' arg is not treated as a special case. # * missing: Native python argparse behavior. # -subcommands = { +subcommands: dict[str, dict[str, Any]] = { "bind": {}, "build": { "arg_mode": "grouped" diff --git a/src/rez/cli/benchmark.py b/src/rez/cli/benchmark.py index a3a568dd6..01ce0b4c8 100644 --- a/src/rez/cli/benchmark.py +++ b/src/rez/cli/benchmark.py @@ -5,6 +5,8 @@ ''' Run a benchmarking suite for runtime resolves. ''' +from __future__ import annotations + import json import os import os.path @@ -17,8 +19,8 @@ # globals opts = None -out_dir = None -pkg_repo_dir = None +out_dir: str | None = None +pkg_repo_dir: str | None = None def setup_parser(parser, completions=False): @@ -47,6 +49,8 @@ def load_packages(): """ from rez.packages import iter_package_families + assert pkg_repo_dir is not None + print("Warming package cache...") fams = list(iter_package_families(paths=[pkg_repo_dir])) diff --git a/src/rez/cli/build.py b/src/rez/cli/build.py index f39057934..ade14e193 100644 --- a/src/rez/cli/build.py +++ b/src/rez/cli/build.py @@ -5,17 +5,23 @@ ''' Build a package from source. ''' +from __future__ import annotations + import os +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from rez.developer_package import DeveloperPackage # Cache the developer package loaded from cwd. This is so the package is only # loaded once, even though it's required once at arg parsing time (to determine # valid build system types), and once at command run time. # -_package = None +_package: DeveloperPackage | None = None -def get_current_developer_package(): +def get_current_developer_package() -> DeveloperPackage: from rez.packages import get_developer_package global _package diff --git a/src/rez/cli/complete.py b/src/rez/cli/complete.py index 8a10d4b15..90a343074 100644 --- a/src/rez/cli/complete.py +++ b/src/rez/cli/complete.py @@ -22,9 +22,9 @@ def command(opts, parser, extra_arg_groups=None): # get comp info from environment variables comp_line = os.getenv("COMP_LINE", "") - comp_point = os.getenv("COMP_POINT", "") + comp_point_str = os.getenv("COMP_POINT", "") try: - comp_point = int(comp_point) + comp_point = int(comp_point_str) except: comp_point = len(comp_line) @@ -60,7 +60,7 @@ def _pop_arg(l, p): cmds = [k for k, v in subcommands.items() if not v.get("hidden")] if prefix: - cmds = (x for x in cmds if x.startswith(prefix)) + cmds = [x for x in cmds if x.startswith(prefix)] print(" ".join(cmds)) if subcommand not in subcommands: diff --git a/src/rez/cli/interpret.py b/src/rez/cli/interpret.py index 6c3f4a4d1..1c4fe9d06 100644 --- a/src/rez/cli/interpret.py +++ b/src/rez/cli/interpret.py @@ -47,7 +47,6 @@ def command(opts, parser, extra_arg_groups=None): with open(opts.FILE) as f: code = f.read() - interp = None if opts.format is None: interp = create_shell() elif opts.format in ('dict', 'table'): diff --git a/src/rez/config.py b/src/rez/config.py index 56b4ca3f2..4fea8dcad 100644 --- a/src/rez/config.py +++ b/src/rez/config.py @@ -2,6 +2,8 @@ # Copyright Contributors to the Rez Project +from __future__ import annotations + from rez import __version__ from rez.utils.data_utils import AttrDictWrapper, RO_AttrDictWrapper, \ convert_dicts, cached_property, cached_class_property, LazyAttributeMeta, \ @@ -15,6 +17,7 @@ from rez.vendor.schema.schema import Schema, SchemaError, And, Or, Use from rez.vendor import yaml from rez.vendor.yaml.error import YAMLError +from rez.utils.typing import Protocol import rez.deprecations from contextlib import contextmanager from functools import lru_cache @@ -22,6 +25,12 @@ import os import re import copy +from typing import TYPE_CHECKING + + +class Validatable(Protocol): + def validate(self, data): + pass class _Deprecation(object): @@ -54,7 +63,7 @@ class Setting(object): Note that lazy setting validation only happens on main configuration settings - plugin settings are validated on load only. """ - schema = Schema(object) + schema: Validatable = Schema(object) def __init__(self, config, key): self.config = config @@ -135,7 +144,7 @@ def _validate(self, data): class Str(Setting): - schema = Schema(str) + schema: Validatable = Schema(str) def _parse_env_var(self, value): return value @@ -153,7 +162,7 @@ class OptionalStr(Str): class StrList(Setting): - schema = Schema([str]) + schema: Validatable = Schema([str]) sep = ',' def _parse_env_var(self, value): @@ -184,8 +193,7 @@ def validate(self, data): class OptionalStrList(StrList): - schema = Or(And(None, Use(lambda x: [])), - [str]) + schema = Or(And(None, Use(lambda x: [])), [str]) class PathList(StrList): @@ -219,7 +227,7 @@ def _parse_env_var(self, value): class Bool(Setting): - schema = Schema(bool) + schema: Validatable = Schema(bool) true_words = frozenset(["1", "true", "t", "yes", "y", "on"]) false_words = frozenset(["0", "false", "f", "no", "n", "off"]) all_words = true_words | false_words @@ -255,7 +263,7 @@ def _parse_env_var(self, value): class Dict(Setting): - schema = Schema(dict) + schema: Validatable = Schema(dict) def _parse_env_var(self, value): items = value.split(",") @@ -549,6 +557,13 @@ class Config(object, metaclass=LazyAttributeMeta): schema = config_schema schema_error = ConfigurationError + if TYPE_CHECKING: + # mypy: The use of LazyAttributeMeta means that this class generates hundreds + # of spurious attribute errors. Adding this for the type analysis will silence + # them until the use of LazyAttributeMeta can be addressed. + def __getattr__(self, item): + pass + def __init__(self, filepaths, overrides=None, locked=False): """Create a config. @@ -749,7 +764,7 @@ def _data(self): return data @classmethod - def _create_main_config(cls, overrides=None): + def _create_main_config(cls, overrides=None) -> Config: """See comment block at top of 'rezconfig' describing how the main config is assembled.""" filepaths = [] diff --git a/src/rez/deprecations.py b/src/rez/deprecations.py index 40b3f5e31..0cd47bc45 100644 --- a/src/rez/deprecations.py +++ b/src/rez/deprecations.py @@ -20,7 +20,7 @@ def warn(message, category, pre_formatted=False, stacklevel=1, filename=None, ** original_formatwarning = warnings.formatwarning if pre_formatted: - def formatwarning(_, category, *args, **kwargs): + def formatwarning(_, category, *args, **kwargs) -> str: return "{0}{1}: {2}\n".format( "{0}: ".format(filename) if filename else "", category.__name__, message ) diff --git a/src/rez/developer_package.py b/src/rez/developer_package.py index 820576f6d..241a64666 100644 --- a/src/rez/developer_package.py +++ b/src/rez/developer_package.py @@ -2,6 +2,8 @@ # Copyright Contributors to the Rez Project +from __future__ import annotations + from rez.config import config from rez.packages import Package, create_package from rez.serialise import load_from_file, FileFormat, set_objects @@ -44,7 +46,7 @@ def root(self): return None @classmethod - def from_path(cls, path, format=None): + def from_path(cls, path, format: FileFormat | None = None): """Load a developer package. A developer package may for example be a package.yaml or package.py in a @@ -62,9 +64,9 @@ def from_path(cls, path, format=None): data = None if format is None: - formats = (FileFormat.py, FileFormat.yaml) + formats = [FileFormat.py, FileFormat.yaml] else: - formats = (format,) + formats = [format] try: mode = os.stat(path).st_mode diff --git a/src/rez/package_bind.py b/src/rez/package_bind.py index be51c1a52..6a04a10cb 100644 --- a/src/rez/package_bind.py +++ b/src/rez/package_bind.py @@ -183,7 +183,7 @@ def _print_package_list(variants): packages = set([x.parent for x in variants]) packages = sorted(packages, key=lambda x: x.name) - rows = [["PACKAGE", "URI"], - ["-------", "---"]] + rows = [("PACKAGE", "URI"), + ("-------", "---")] rows += [(x.name, x.uri) for x in packages] print('\n'.join(columnise(rows))) diff --git a/src/rez/package_copy.py b/src/rez/package_copy.py index 353e1842d..fe70f01e4 100644 --- a/src/rez/package_copy.py +++ b/src/rez/package_copy.py @@ -2,6 +2,8 @@ # Copyright Contributors to the Rez Project +from __future__ import annotations + from functools import partial import os.path import shutil @@ -9,8 +11,8 @@ from rez.config import config from rez.exceptions import PackageCopyError -from rez.package_repository import package_repository_manager -from rez.packages import Variant +from rez.package_repository import package_repository_manager, PackageRepository +from rez.packages import Package, Variant from rez.serialise import FileFormat from rez.utils import with_noop from rez.utils.base26 import create_unique_base26_symlink @@ -20,10 +22,11 @@ safe_makedirs, additive_copytree, make_path_writable, get_existing_path -def copy_package(package, dest_repository, variants=None, shallow=False, - dest_name=None, dest_version=None, overwrite=False, force=False, - follow_symlinks=False, dry_run=False, keep_timestamp=False, - skip_payload=False, overrides=None, verbose=False): +def copy_package(package: Package, dest_repository: PackageRepository, + variants: list[int] | None = None, shallow: bool = False, + dest_name=None, dest_version=None, overwrite: bool = False, force: bool = False, + follow_symlinks: bool = False, dry_run: bool = False, keep_timestamp: bool = False, + skip_payload: bool = False, overrides=None, verbose: bool = False): """Copy a package from one package repository to another. This copies the package definition and payload. The package can also be @@ -227,8 +230,8 @@ def finalize(): return finalize() -def _copy_variant_payload(src_variant, dest_pkg_repo, shallow=False, - follow_symlinks=False, overrides=None, verbose=False): +def _copy_variant_payload(src_variant: Variant, dest_pkg_repo: PackageRepository, shallow: bool = False, + follow_symlinks: bool = False, overrides=None, verbose: bool = False): # Get payload path of source variant. For some types (eg from a "memory" # type repo) there may not be a root. # @@ -243,7 +246,7 @@ def _copy_variant_payload(src_variant, dest_pkg_repo, shallow=False, if not os.path.isdir(variant_root): raise PackageCopyError( "Cannot copy source variant %s - its root does not appear to " - "be present on disk (%s)." % src_variant.uri, variant_root + "be present on disk (%s)." % (src_variant.uri, variant_root) ) dest_variant_name = overrides.get("name") or src_variant.name diff --git a/src/rez/package_filter.py b/src/rez/package_filter.py index 70014e400..c6e64fb5c 100644 --- a/src/rez/package_filter.py +++ b/src/rez/package_filter.py @@ -2,12 +2,15 @@ # Copyright Contributors to the Rez Project +from __future__ import annotations + from rez.packages import iter_packages from rez.exceptions import ConfigurationError from rez.config import config from rez.utils.data_utils import cached_property, cached_class_property from rez.version import VersionedObject, Requirement from hashlib import sha1 +from typing import Pattern import fnmatch import re @@ -327,7 +330,8 @@ class Rule(object): """Base package filter rule""" #: Rule name - name = None + name: str + _family: str | None def match(self, package): """Apply the rule to the package. @@ -340,7 +344,7 @@ def match(self, package): """ raise NotImplementedError - def family(self): + def family(self) -> str | None: """Returns a package family string if this rule only applies to a given package family, otherwise None. @@ -365,11 +369,12 @@ def parse_rule(cls, txt): Returns: Rule: """ - types = {"glob": GlobRule, - "regex": RegexRule, - "range": RangeRule, - "before": TimestampRule, - "after": TimestampRule} + types: dict[str, type[Rule]] = { + "glob": GlobRule, + "regex": RegexRule, + "range": RangeRule, + "before": TimestampRule, + "after": TimestampRule} # parse form 'x(y)' into x, y label, txt = Rule._parse_label(txt) @@ -412,7 +417,7 @@ def _parse_label(cls, txt): return None, txt @classmethod - def _extract_family(cls, txt): + def _extract_family(cls, txt) -> str | None: m = cls.family_re.match(txt) if m: return m.group()[:-1] @@ -426,14 +431,17 @@ def __repr__(self): class RegexRuleBase(Rule): - def match(self, package): + regex: Pattern[str] + txt: str + + def match(self, package) -> bool: return bool(self.regex.match(package.qualified_name)) def cost(self): return 10 @classmethod - def _parse(cls, txt): + def _parse(cls, txt: str): _, txt = Rule._parse_label(txt) return cls(txt) @@ -448,7 +456,7 @@ class RegexRule(RegexRuleBase): """ name = "regex" - def __init__(self, s): + def __init__(self, s: str): """Create a regex rule. Args: @@ -466,7 +474,7 @@ class GlobRule(RegexRuleBase): """ name = "glob" - def __init__(self, s): + def __init__(self, s: str): """Create a glob rule. Args: diff --git a/src/rez/package_maker.py b/src/rez/package_maker.py index e547e0287..3b3b94069 100644 --- a/src/rez/package_maker.py +++ b/src/rez/package_maker.py @@ -2,6 +2,8 @@ # Copyright Contributors to the Rez Project +from __future__ import annotations + from rez.utils._version import _rez_version from rez.utils.schema import Required, extensible_schema_dict from rez.utils.filesystem import retain_cwd @@ -12,12 +14,13 @@ from rez.package_resources import help_schema, _commands_schema, \ _function_schema, late_bound from rez.package_repository import create_memory_package_repository -from rez.packages import Package +from rez.packages import Package, Variant from rez.package_py_utils import expand_requirement from rez.vendor.schema.schema import Schema, Optional, Or, Use, And from rez.version import Version from contextlib import contextmanager import os +from typing import Iterable # this schema will automatically harden request strings like 'python-*'; see @@ -92,7 +95,7 @@ class PackageMaker(AttrDictWrapper): """Utility class for creating packages.""" - def __init__(self, name, data=None, package_cls=None): + def __init__(self, name: str, data=None, package_cls: type[Package] | None = None): """Create a package maker. Args: @@ -106,7 +109,7 @@ def __init__(self, name, data=None, package_cls=None): self.installed_variants = [] self.skipped_variants = [] - def get_package(self): + def get_package(self) -> Package: """Create the analogous package. Returns: @@ -133,6 +136,7 @@ def get_package(self): # retrieve the package from the new repository family_resource = repo.get_package_family(self.name) + assert family_resource is not None it = repo.iter_packages(family_resource) package_resource = next(it) @@ -197,19 +201,20 @@ def make_package(name, path, make_base=None, make_root=None, skip_existing=True, # package = maker.get_package() - src_variants = [] # skip those variants that already exist if skip_existing: + variants: list[Variant] = [] for variant in package.iter_variants(): variant_ = variant.install(path, dry_run=True) if variant_ is None: - src_variants.append(variant) + variants.append(variant) else: maker.skipped_variants.append(variant_) if warn_on_skip: print_warning("Skipping installation: Package variant already " "exists: %s" % variant_.uri) + src_variants: Iterable[Variant] = variants else: src_variants = package.iter_variants() diff --git a/src/rez/package_order.py b/src/rez/package_order.py index e60662d77..de1d548fb 100644 --- a/src/rez/package_order.py +++ b/src/rez/package_order.py @@ -2,15 +2,23 @@ # Copyright Contributors to the Rez Project +from __future__ import annotations + from inspect import isclass from hashlib import sha1 -from typing import Dict, Iterable, List, Optional, Union +from typing import Any, Callable, Iterable, List, TYPE_CHECKING from rez.config import config from rez.utils.data_utils import cached_class_property from rez.version import Version, VersionRange from rez.version._version import _Comparable, _ReversedComparable, _LowerBound, _UpperBound, _Bound -from rez.packages import iter_packages +from rez.packages import iter_packages, Package +from rez.utils.typing import SupportsLessThan + +if TYPE_CHECKING: + # this is not available in typing until 3.11, but due to __future__.annotations + # we can use it without really importing it + from typing import Self ALL_PACKAGES = "*" @@ -20,17 +28,23 @@ class FallbackComparable(_Comparable): fails, compares using the fallback_comparable object. """ - def __init__(self, main_comparable, fallback_comparable): + def __init__(self, + main_comparable: SupportsLessThan, + fallback_comparable: SupportsLessThan): self.main_comparable = main_comparable self.fallback_comparable = fallback_comparable - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + if not isinstance(other, FallbackComparable): + return NotImplemented try: return self.main_comparable == other.main_comparable except Exception: return self.fallback_comparable == other.fallback_comparable - def __lt__(self, other): + def __lt__(self, other: object) -> bool: + if not isinstance(other, FallbackComparable): + return NotImplemented try: return self.main_comparable < other.main_comparable except Exception: @@ -44,9 +58,10 @@ class PackageOrder(object): """Package reorderer base class.""" #: Orderer name - name = None + name: str + _packages: list[str] - def __init__(self, packages: Optional[Iterable[str]] = None): + def __init__(self, packages: Iterable[str] | None = None): """ Args: packages: If not provided, PackageOrder applies to all packages. @@ -54,7 +69,7 @@ def __init__(self, packages: Optional[Iterable[str]] = None): self.packages = packages @property - def packages(self) -> List[str]: + def packages(self) -> list[str]: """Returns an iterable over the list of package family names that this order applies to @@ -64,7 +79,7 @@ def packages(self) -> List[str]: return self._packages @packages.setter - def packages(self, packages: Union[str, Iterable[str]]): + def packages(self, packages: str | Iterable[str] | None): if packages is None: # Apply to all packages self._packages = [ALL_PACKAGES] @@ -73,7 +88,8 @@ def packages(self, packages: Union[str, Iterable[str]]): else: self._packages = sorted(packages) - def reorder(self, iterable, key=None): + def reorder(self, iterable: Iterable[Package], + key: Callable[[Any], Package] | None = None) -> list[Package] | None: """Put packages into some order for consumption. You can safely assume that the packages referred to by `iterable` are @@ -101,7 +117,9 @@ def reorder(self, iterable, key=None): reverse=True) @staticmethod - def _get_package_name_from_iterable(iterable, key=None): + def _get_package_name_from_iterable(iterable: Iterable[Package], + key: Callable[[Any], Package] | None = None + ) -> str | None: """Utility method for getting a package from an iterable""" try: item = next(iter(iterable)) @@ -111,7 +129,7 @@ def _get_package_name_from_iterable(iterable, key=None): key = key or (lambda x: x) return key(item).name - def sort_key(self, package_name, version_like): + def sort_key(self, package_name: str, version_like) -> SupportsLessThan: """Returns a sort key usable for sorting packages within the same family Args: @@ -148,7 +166,7 @@ def sort_key(self, package_name, version_like): return 0 raise TypeError(version_like) - def sort_key_implementation(self, package_name, version): + def sort_key_implementation(self, package_name: str, version: Version) -> SupportsLessThan: """Returns a sort key usable for sorting these packages within the same family Args: @@ -170,10 +188,10 @@ def from_pod(cls, data): raise NotImplementedError @property - def sha1(self): + def sha1(self) -> str: return sha1(repr(self).encode('utf-8')).hexdigest() - def __str__(self): + def __str__(self) -> str: raise NotImplementedError def __eq__(self, other): @@ -182,7 +200,7 @@ def __eq__(self, other): def __ne__(self, other): return not self == other - def __repr__(self): + def __repr__(self) -> str: return "%s(%s)" % (self.__class__.__name__, str(self)) @@ -195,12 +213,12 @@ class NullPackageOrder(PackageOrder): """ name = "no_order" - def sort_key_implementation(self, package_name, version): + def sort_key_implementation(self, package_name: str, version: Version) -> SupportsLessThan: # python's sort will preserve the order of items that compare equal, so # to not change anything, we just return the same object for all... return 0 - def __str__(self): + def __str__(self) -> str: return "{}" def __eq__(self, other): @@ -233,7 +251,7 @@ def __init__(self, descending, packages=None): super().__init__(packages) self.descending = descending - def sort_key_implementation(self, package_name, version): + def sort_key_implementation(self, package_name: str, version: Version) -> SupportsLessThan: # Note that the name "descending" can be slightly confusing - it # indicates that the final ordering this Order gives should be # version descending (ie, the default) - however, the sort_key itself @@ -246,7 +264,7 @@ def sort_key_implementation(self, package_name, version): else: return _ReversedComparable(version) - def __str__(self): + def __str__(self) -> str: return str(self.descending) def __eq__(self, other): @@ -283,7 +301,7 @@ class PerFamilyOrder(PackageOrder): """ name = "per_family" - def __init__(self, order_dict, default_order=None): + def __init__(self, order_dict: dict[str, PackageOrder], default_order=None): """Create a reorderer. Args: @@ -296,7 +314,8 @@ def __init__(self, order_dict, default_order=None): self.order_dict = order_dict.copy() self.default_order = default_order - def reorder(self, iterable, key=None): + def reorder(self, iterable: Iterable[Package], + key: Callable[[Any], Package] | None = None) -> list[Package] | None: package_name = self._get_package_name_from_iterable(iterable, key) if package_name is None: return None @@ -309,7 +328,7 @@ def reorder(self, iterable, key=None): return orderer.reorder(iterable, key) - def sort_key_implementation(self, package_name, version): + def sort_key_implementation(self, package_name: str, version: Version) -> SupportsLessThan: orderer = self.order_dict.get(package_name) if orderer is None: if self.default_order is None: @@ -322,7 +341,7 @@ def sort_key_implementation(self, package_name, version): return orderer.sort_key_implementation(package_name, version) - def __str__(self): + def __str__(self) -> str: items = sorted((x[0], str(x[1])) for x in self.order_dict.items()) return str((items, str(self.default_order))) @@ -402,7 +421,7 @@ class VersionSplitPackageOrder(PackageOrder): """ name = "version_split" - def __init__(self, first_version, packages=None): + def __init__(self, first_version: Version, packages=None): """Create a reorderer. Args: @@ -411,7 +430,7 @@ def __init__(self, first_version, packages=None): super().__init__(packages) self.first_version = first_version - def sort_key_implementation(self, package_name, version): + def sort_key_implementation(self, package_name: str, version: Version) -> SupportsLessThan: priority_key = 1 if version <= self.first_version else 0 return priority_key, version @@ -490,7 +509,7 @@ class TimestampPackageOrder(PackageOrder): """ name = "soft_timestamp" - def __init__(self, timestamp, rank=0, packages=None): + def __init__(self, timestamp: int, rank: int = 0, packages=None): """Create a reorderer. Args: @@ -555,7 +574,7 @@ def _calc_sort_key(self, package_name, version): first_after = self._get_first_after(package_name) if first_after is None: # all packages are before T - is_before = True + is_before: bool | int = True else: is_before = int(version < first_after) @@ -569,7 +588,7 @@ def _calc_sort_key(self, package_name, version): return is_before, _ReversedComparable(version) - def sort_key_implementation(self, package_name, version): + def sort_key_implementation(self, package_name: str, version: Version) -> SupportsLessThan: cache_key = (package_name, str(version)) result = self._cached_sort_key.get(cache_key) if result is None: @@ -578,7 +597,7 @@ def sort_key_implementation(self, package_name, version): return result - def __str__(self): + def __str__(self) -> str: return str((self.timestamp, self.rank)) def __eq__(self, other): @@ -614,13 +633,13 @@ def from_pod(cls, data): ) -class PackageOrderList(list): +class PackageOrderList(List[PackageOrder]): """A list of package orderer. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.by_package: Dict[str, PackageOrder] = {} + self.by_package: dict[str, PackageOrder] = {} self.dirty = True def to_pod(self): @@ -635,12 +654,12 @@ def from_pod(cls, data): return flist @cached_class_property - def singleton(cls): + def singleton(cls) -> Self: """Filter list as configured by rezconfig.package_filter.""" return cls.from_pod(config.package_orderers) @staticmethod - def _to_orderer(orderer: Union[dict, PackageOrder]) -> PackageOrder: + def _to_orderer(orderer: dict | PackageOrder) -> PackageOrder: if isinstance(orderer, dict): orderer = from_pod(orderer) return orderer @@ -657,31 +676,32 @@ def refresh(self) -> None: continue self.by_package[package] = orderer - def append(self, *args, **kwargs): - self.dirty = True - return super().append(*args, **kwargs) + if not TYPE_CHECKING: + def append(self, *args, **kwargs): + self.dirty = True + return super().append(*args, **kwargs) - def extend(self, *args, **kwargs): - self.dirty = True - return super().extend(*args, **kwargs) + def extend(self, *args, **kwargs): + self.dirty = True + return super().extend(*args, **kwargs) - def pop(self, *args, **kwargs): - self.dirty = True - return super().pop(*args, **kwargs) + def pop(self, *args, **kwargs): + self.dirty = True + return super().pop(*args, **kwargs) - def remove(self, *args, **kwargs): - self.dirty = True - return super().remove(*args, **kwargs) + def remove(self, *args, **kwargs): + self.dirty = True + return super().remove(*args, **kwargs) - def clear(self, *args, **kwargs): - self.dirty = True - return super().clear(*args, **kwargs) + def clear(self, *args, **kwargs): + self.dirty = True + return super().clear(*args, **kwargs) - def insert(self, *args, **kwargs): - self.dirty = True - return super().insert(*args, **kwargs) + def insert(self, *args, **kwargs): + self.dirty = True + return super().insert(*args, **kwargs) - def get(self, key: str, default: Optional[PackageOrder] = None) -> PackageOrder: + def get(self, key: str, default: PackageOrder | None = None) -> PackageOrder | None: """ Get an orderer that sorts a package by name. """ @@ -698,7 +718,7 @@ def to_pod(orderer): return data -def from_pod(data): +def from_pod(data) -> PackageOrder: if isinstance(data, dict): cls_name = data["type"] data = data.copy() diff --git a/src/rez/package_repository.py b/src/rez/package_repository.py index 2c7be5ab2..1a323593e 100644 --- a/src/rez/package_repository.py +++ b/src/rez/package_repository.py @@ -2,6 +2,8 @@ # Copyright Contributors to the Rez Project +from __future__ import annotations + from rez.utils.resources import ResourcePool, ResourceHandle from rez.utils.data_utils import cached_property from rez.plugin_managers import plugin_manager @@ -11,6 +13,14 @@ import threading import os.path import time +from typing import Any, Hashable, Iterator, TYPE_CHECKING + +if TYPE_CHECKING: + from rez.package_resources import (PackageFamilyResource, PackageResource, PackageResourceHelper, + VariantResource, PackageRepositoryResource) + from rez.utils.resources import Resource + from rez.version import Version + from rezplugins.package_repository.memory import MemoryPackageRepository def get_package_repository_types(): @@ -18,7 +28,7 @@ def get_package_repository_types(): return plugin_manager.get_plugins('package_repository') -def create_memory_package_repository(repository_data): +def create_memory_package_repository(repository_data: dict) -> MemoryPackageRepository: """Create a standalone in-memory package repository from the data given. See rezplugins/package_repository/memory.py for more details. @@ -29,7 +39,8 @@ def create_memory_package_repository(repository_data): Returns: `PackageRepository` object. """ - cls_ = plugin_manager.get_plugin_class("package_repository", "memory") + from rezplugins.package_repository.memory import MemoryPackageRepository # noqa + cls_ = plugin_manager.get_plugin_class("package_repository", "memory", MemoryPackageRepository) return cls_.create_repository(repository_data) @@ -69,11 +80,11 @@ class PackageRepository(object): remove = object() @classmethod - def name(cls): + def name(cls) -> str: """Return the name of the package repository type.""" raise NotImplementedError - def __init__(self, location, resource_pool): + def __init__(self, location: str, resource_pool: ResourcePool): """Create a package repository. Args: @@ -85,10 +96,10 @@ def __init__(self, location, resource_pool): self.location = location self.pool = resource_pool - def __str__(self): + def __str__(self) -> str: return "%s@%s" % (self.name(), self.location) - def register_resource(self, resource_class): + def register_resource(self, resource_class: type[Resource]) -> None: """Register a resource with the repository. Your derived repository class should call this method in its __init__ to @@ -96,12 +107,12 @@ def register_resource(self, resource_class): """ self.pool.register_resource(resource_class) - def clear_caches(self): + def clear_caches(self) -> None: """Clear any cached resources in the pool.""" self.pool.clear_caches() @cached_property - def uid(self): + def uid(self) -> tuple[str, str]: """Returns a unique identifier for this repository. This must be a persistent identifier, for example a filepath, or @@ -119,7 +130,7 @@ def __eq__(self, other): and other.uid == self.uid ) - def is_empty(self): + def is_empty(self) -> bool: """Determine if the repository contains any packages. Returns: @@ -131,7 +142,7 @@ def is_empty(self): return True - def get_package_family(self, name): + def get_package_family(self, name) -> PackageFamilyResource | None: """Get a package family. Args: @@ -142,7 +153,7 @@ def get_package_family(self, name): """ raise NotImplementedError - def iter_package_families(self): + def iter_package_families(self) -> Iterator[PackageFamilyResource]: """Iterate over the package families in the repository, in no particular order. @@ -151,7 +162,7 @@ def iter_package_families(self): """ raise NotImplementedError - def iter_packages(self, package_family_resource): + def iter_packages(self, package_family_resource) -> Iterator[PackageResource]: """Iterate over the packages within the given family, in no particular order. @@ -163,7 +174,7 @@ def iter_packages(self, package_family_resource): """ raise NotImplementedError - def iter_variants(self, package_resource): + def iter_variants(self, package_resource: PackageResource) -> Iterator[VariantResource]: """Iterate over the variants within the given package. Args: @@ -174,7 +185,7 @@ def iter_variants(self, package_resource): """ raise NotImplementedError - def get_package(self, name, version): + def get_package(self, name: str, version: Version) -> PackageResourceHelper | None: """Get a package. Args: @@ -182,7 +193,7 @@ def get_package(self, name, version): version (`Version`): Package version. Returns: - `PackageResource` or None: Matching package, or None if not found. + `PackageResourceHelper` or None: Matching package, or None if not found. """ fam = self.get_package_family(name) if fam is None: @@ -194,7 +205,7 @@ def get_package(self, name, version): return None - def get_package_from_uri(self, uri): + def get_package_from_uri(self, uri: str) -> PackageResource | None: """Get a package given its URI. Args: @@ -206,7 +217,7 @@ def get_package_from_uri(self, uri): """ return None - def get_variant_from_uri(self, uri): + def get_variant_from_uri(self, uri: str) -> VariantResource | None: """Get a variant given its URI. Args: @@ -218,7 +229,7 @@ def get_variant_from_uri(self, uri): """ return None - def ignore_package(self, pkg_name, pkg_version, allow_missing=False): + def ignore_package(self, pkg_name: str, pkg_version: Version, allow_missing=False) -> int: """Ignore the given package. Ignoring a package makes it invisible to further resolves. @@ -239,7 +250,7 @@ def ignore_package(self, pkg_name, pkg_version, allow_missing=False): """ raise NotImplementedError - def unignore_package(self, pkg_name, pkg_version): + def unignore_package(self, pkg_name: str, pkg_version: Version) -> int: """Unignore the given package. Args: @@ -254,7 +265,7 @@ def unignore_package(self, pkg_name, pkg_version): """ raise NotImplementedError - def remove_package(self, pkg_name, pkg_version): + def remove_package(self, pkg_name: str, pkg_version: Version) -> bool: """Remove a package. Note that this should work even if the specified package is currently @@ -269,7 +280,7 @@ def remove_package(self, pkg_name, pkg_version): """ raise NotImplementedError - def remove_package_family(self, pkg_name, force=False): + def remove_package_family(self, pkg_name: str, force: bool = False) -> bool: """Remove an empty package family. Args: @@ -281,7 +292,8 @@ def remove_package_family(self, pkg_name, force=False): """ raise NotImplementedError - def remove_ignored_since(self, days, dry_run=False, verbose=False): + def remove_ignored_since(self, days: int, dry_run: bool = False, + verbose: bool = False) -> int: """Remove packages ignored for >= specified number of days. Args: @@ -295,7 +307,7 @@ def remove_ignored_since(self, days, dry_run=False, verbose=False): """ raise NotImplementedError - def pre_variant_install(self, variant_resource): + def pre_variant_install(self, variant_resource: VariantResource): """Called before a variant is installed. If any directories are created on disk for the variant to install into, @@ -306,7 +318,7 @@ def pre_variant_install(self, variant_resource): """ pass - def on_variant_install_cancelled(self, variant_resource): + def on_variant_install_cancelled(self, variant_resource: VariantResource): """Called when a variant installation is cancelled. This is called after `pre_variant_install`, but before `install_variant`, @@ -321,7 +333,10 @@ def on_variant_install_cancelled(self, variant_resource): """ pass - def install_variant(self, variant_resource, dry_run=False, overrides=None): + def install_variant(self, + variant_resource: VariantResource, + dry_run: bool = False, + overrides: dict[str, Any] | None = None) -> VariantResource: """Install a variant into this repository. Use this function to install a variant from some other package repository @@ -343,7 +358,7 @@ def install_variant(self, variant_resource, dry_run=False, overrides=None): """ raise NotImplementedError - def get_equivalent_variant(self, variant_resource): + def get_equivalent_variant(self, variant_resource: VariantResource) -> VariantResource: """Find a variant in this repository that is equivalent to that given. A variant is equivalent to another if it belongs to a package of the @@ -362,7 +377,7 @@ def get_equivalent_variant(self, variant_resource): """ return self.install_variant(variant_resource, dry_run=True) - def get_parent_package_family(self, package_resource): + def get_parent_package_family(self, package_resource: PackageResourceHelper) -> PackageFamilyResource: """Get the parent package family of the given package. Args: @@ -373,7 +388,7 @@ def get_parent_package_family(self, package_resource): """ raise NotImplementedError - def get_parent_package(self, variant_resource): + def get_parent_package(self, variant_resource: VariantResource) -> PackageRepositoryResource: """Get the parent package of the given variant. Args: @@ -384,7 +399,8 @@ def get_parent_package(self, variant_resource): """ raise NotImplementedError - def get_variant_state_handle(self, variant_resource): + def get_variant_state_handle(self, variant_resource: PackageResource + ) -> Hashable | None: """Get a value that indicates the state of the variant. This is used for resolve caching. For example, in the 'filesystem' @@ -400,7 +416,8 @@ def get_variant_state_handle(self, variant_resource): """ return None - def get_last_release_time(self, package_family_resource): + def get_last_release_time(self, package_family_resource: PackageFamilyResource + ) -> int: """Get the last time a package was added to the given family. This information is used to cache resolves via memcached. It can be left @@ -414,7 +431,7 @@ def get_last_release_time(self, package_family_resource): """ return 0 - def make_resource_handle(self, resource_key, **variables): + def make_resource_handle(self, resource_key: str, **variables) -> ResourceHandle: """Create a `ResourceHandle` Nearly all `ResourceHandle` creation should go through here, because it @@ -438,7 +455,7 @@ def make_resource_handle(self, resource_key, **variables): variables = resource_cls.normalize_variables(variables) return ResourceHandle(resource_key, variables) - def get_resource(self, resource_key, **variables): + def get_resource(self, resource_key: str, **variables) -> Resource: """Get a resource. Attempts to get and return a cached version of the resource if @@ -454,7 +471,8 @@ def get_resource(self, resource_key, **variables): handle = self.make_resource_handle(resource_key, **variables) return self.get_resource_from_handle(handle, verify_repo=False) - def get_resource_from_handle(self, resource_handle, verify_repo=True): + def get_resource_from_handle(self, resource_handle: ResourceHandle, + verify_repo: bool = True) -> Resource: """Get a resource. Args: @@ -484,7 +502,7 @@ def get_resource_from_handle(self, resource_handle, verify_repo=True): resource._repository = self return resource - def get_package_payload_path(self, package_name, package_version=None): + def get_package_payload_path(self, package_name: str, package_version=None) -> str: """Defines where a package's payload should be installed to. Args: @@ -496,7 +514,7 @@ def get_package_payload_path(self, package_name, package_version=None): """ raise NotImplementedError - def _uid(self): + def _uid(self) -> tuple[str, str]: """Unique identifier implementation. You may need to provide your own implementation. For example, consider @@ -517,7 +535,7 @@ class PackageRepositoryManager(object): Manages retrieval of resources (packages and variants) from `PackageRepository` instances, and caches these resources in a resource pool. """ - def __init__(self, resource_pool=None): + def __init__(self, resource_pool: ResourcePool | None = None): """Create a package repo manager. Args: @@ -532,9 +550,9 @@ def __init__(self, resource_pool=None): resource_pool = ResourcePool(cache_size=cache_size) self.pool = resource_pool - self.repositories = {} + self.repositories: dict[str, PackageRepository] = {} - def get_repository(self, path): + def get_repository(self, path: str) -> PackageRepository: """Get a package repository. Args: @@ -550,9 +568,10 @@ def get_repository(self, path): # normalise repo path parts = path.split('@', 1) if len(parts) == 1: - parts = ("filesystem", parts[0]) + repo_type, location = ("filesystem", parts[0]) + else: + repo_type, location = parts - repo_type, location = parts if repo_type == "filesystem": # choice of abspath here vs realpath is deliberate. Realpath gives # canonical path, which can be a problem if two studios are sharing @@ -573,7 +592,7 @@ def get_repository(self, path): return repository - def are_same(self, path_1, path_2): + def are_same(self, path_1, path_2) -> bool: """Test that `path_1` and `path_2` refer to the same repository. This is more reliable than testing that the strings match, since slightly @@ -590,8 +609,8 @@ def are_same(self, path_1, path_2): repo_2 = self.get_repository(path_2) return (repo_1.uid == repo_2.uid) - def get_resource(self, resource_key, repository_type, location, - **variables): + def get_resource(self, resource_key: str, repository_type: str, + location: str, **variables) -> Resource: """Get a resource. Attempts to get and return a cached version of the resource if @@ -612,7 +631,8 @@ def get_resource(self, resource_key, repository_type, location, resource = repo.get_resource(**variables) return resource - def get_resource_from_handle(self, resource_handle): + def get_resource_from_handle(self, resource_handle: ResourceHandle + ) -> Resource: """Get a resource. Args: @@ -632,14 +652,14 @@ def get_resource_from_handle(self, resource_handle): resource = repo.get_resource_from_handle(resource_handle) return resource - def clear_caches(self): + def clear_caches(self) -> None: """Clear all cached data.""" self.repositories.clear() self.pool.clear_caches() - def _get_repository(self, path, **repo_args): + def _get_repository(self, path: str, **repo_args) -> PackageRepository: repo_type, location = path.split('@', 1) - cls = plugin_manager.get_plugin_class('package_repository', repo_type) + cls = plugin_manager.get_plugin_class('package_repository', repo_type, PackageRepository) repo = cls(location, self.pool, **repo_args) return repo diff --git a/src/rez/package_resources.py b/src/rez/package_resources.py index f33c4270a..4d475058e 100644 --- a/src/rez/package_resources.py +++ b/src/rez/package_resources.py @@ -2,6 +2,8 @@ # Copyright Contributors to the Rez Project +from __future__ import annotations + from rez.utils.resources import Resource from rez.utils.schema import Required, schema_keys, extensible_schema_dict from rez.utils.logging_ import print_warning @@ -12,12 +14,18 @@ from rez.utils.formatting import PackageRequest from rez.exceptions import PackageMetadataError, ResourceError from rez.config import config, Config, create_config -from rez.version import Version +from rez.version import Requirement, Version from rez.vendor.schema.schema import Schema, SchemaError, Optional, Or, And, Use from textwrap import dedent import os.path +from abc import abstractmethod from hashlib import sha1 +from typing import Any, Iterable, Iterator, TYPE_CHECKING +from types import FunctionType, MethodType + +if TYPE_CHECKING: + from rez.packages import Variant # package attributes created at release time @@ -76,7 +84,7 @@ def late_bound(schema): # requirements of all package-related resources # -base_resource_schema_dict = { +base_resource_schema_dict: dict[Schema, Any] = { Required("name"): str } @@ -269,7 +277,7 @@ class PackageRepositoryResource(Resource): """ schema_error = PackageMetadataError #: Type of package repository associated with this resource type. - repository_type = None + repository_type: str @classmethod def normalize_variables(cls, variables): @@ -284,18 +292,18 @@ def __init__(self, variables=None): super(PackageRepositoryResource, self).__init__(variables) @cached_property - def uri(self): + def uri(self) -> str: return self._uri() @property - def location(self): + def location(self) -> str | None: return self.get("location") @property - def name(self): + def name(self) -> str | None: return self.get("name") - def _uri(self): + def _uri(self) -> str: """Return a URI. Implement this function to return a short, readable string that @@ -310,7 +318,9 @@ class PackageFamilyResource(PackageRepositoryResource): A repository implementation's package family resource(s) must derive from this class. It must satisfy the schema `package_family_schema`. """ - pass + + def iter_packages(self) -> Iterator[PackageResourceHelper]: + raise NotImplementedError class PackageResource(PackageRepositoryResource): @@ -330,7 +340,7 @@ def normalize_variables(cls, variables): return super(PackageResource, cls).normalize_variables(variables) @cached_property - def version(self): + def version(self) -> Version: ver_str = self.get("version", "") return Version(ver_str) @@ -345,17 +355,23 @@ class VariantResource(PackageResource): this case it is the 'None' variant (the value of `index` is None). This provides some internal consistency and simplifies the implementation. """ + + @property + @abstractmethod + def parent(self) -> PackageRepositoryResource: + raise NotImplementedError + @property - def index(self): + def index(self) -> int | None: return self.get("index", None) @cached_property - def root(self): + def root(self) -> str: """Return the 'root' path of the variant.""" return self._root() @cached_property - def subpath(self): + def subpath(self) -> str: """Return the variant's 'subpath' The subpath is the relative path the variant's payload should be stored @@ -364,9 +380,11 @@ def subpath(self): """ return self._subpath() + @abstractmethod def _root(self, ignore_shortlinks=False): raise NotImplementedError + @abstractmethod def _subpath(self, ignore_shortlinks=False): raise NotImplementedError @@ -381,25 +399,43 @@ def _subpath(self, ignore_shortlinks=False): class PackageResourceHelper(PackageResource): """PackageResource with some common functionality included. """ - variant_key = None + # the resource key for a VariantResourceHelper subclass + variant_key: str + + if TYPE_CHECKING: + # I think these attributes are provided dynamically be LazyAttributeMeta + _commands: list[str] | str | FunctionType | MethodType | SourceCode + _pre_commands: list[str] | str | FunctionType | MethodType | SourceCode + _post_commands: list[str] | str | FunctionType | MethodType | SourceCode + variants: list[Variant] + + @property + @abstractmethod + def base(self) -> str | None: + raise NotImplementedError + + @property + @abstractmethod + def parent(self) -> PackageRepositoryResource: + raise NotImplementedError @cached_property - def commands(self): + def commands(self) -> SourceCode: return self._convert_to_rex(self._commands) @cached_property - def pre_commands(self): + def pre_commands(self) -> SourceCode: return self._convert_to_rex(self._pre_commands) @cached_property - def post_commands(self): + def post_commands(self) -> SourceCode: return self._convert_to_rex(self._post_commands) - def iter_variants(self): + def iter_variants(self) -> Iterator[VariantResourceHelper]: num_variants = len(self.variants or []) if num_variants == 0: - indexes = [None] + indexes: Iterable[int | None] = [None] else: indexes = range(num_variants) @@ -412,7 +448,7 @@ def iter_variants(self): index=index) yield variant - def _convert_to_rex(self, commands): + def _convert_to_rex(self, commands: list[str] | str | FunctionType | MethodType | SourceCode) -> SourceCode: if isinstance(commands, list): from rez.utils.backcompat import convert_old_commands @@ -453,12 +489,12 @@ class VariantResourceHelper(VariantResource, metaclass=_Metas): # forward Package attributes onto ourself keys = schema_keys(package_schema) - set(["variants"]) - def _uri(self): + def _uri(self) -> str: index = self.index idxstr = '' if index is None else str(index) return "%s[%s]" % (self.parent.uri, idxstr) - def _subpath(self, ignore_shortlinks=False): + def _subpath(self, ignore_shortlinks=False) -> str | None: if self.index is None: return None @@ -488,7 +524,7 @@ def _subpath(self, ignore_shortlinks=False): subpath = os.path.join(*dirs) return subpath - def _root(self, ignore_shortlinks=False): + def _root(self, ignore_shortlinks: bool = False) -> str | None: if self.base is None: return None elif self.index is None: @@ -499,7 +535,7 @@ def _root(self, ignore_shortlinks=False): return root @cached_property - def variant_requires(self): + def variant_requires(self) -> list[Requirement]: index = self.index if index is None: return [] diff --git a/src/rez/package_test.py b/src/rez/package_test.py index 3c85882d1..8159aeb1c 100644 --- a/src/rez/package_test.py +++ b/src/rez/package_test.py @@ -171,7 +171,7 @@ def _select(value): ) if ran_once: - def _select(key, value): + def _select_kv(key, value): if isinstance(value, dict): value = value.get("on_variants") else: @@ -184,7 +184,7 @@ def _select(key, value): tests_dict = dict( (k, v) for k, v in tests_dict.items() - if _select(k, v) + if _select_kv(k, v) ) return sorted(tests_dict.keys()) @@ -514,8 +514,8 @@ def _on_variant_requires(self, variant, params): # If the combined requirements, minus conflict requests, is equal to the # variant's requirements, then this variant is selected. # - reqs1 = RequirementList(x for x in reqlist if not x.conflict) - reqs2 = RequirementList(x for x in variant.variant_requires if not x.conflict) + reqs1 = RequirementList([x for x in reqlist if not x.conflict]) + reqs2 = RequirementList([x for x in variant.variant_requires if not x.conflict]) return (reqs1 == reqs2) def _get_test_info(self, test_name, variant): diff --git a/src/rez/packages.py b/src/rez/packages.py index dc816303b..39fb1241d 100644 --- a/src/rez/packages.py +++ b/src/rez/packages.py @@ -2,6 +2,8 @@ # Copyright Contributors to the Rez Project +from __future__ import annotations + from rez.package_repository import package_repository_manager from rez.package_resources import PackageFamilyResource, PackageResource, \ VariantResource, package_family_schema, package_schema, variant_schema, \ @@ -21,6 +23,16 @@ import os import sys +from typing import overload, Any, Iterator, TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Literal # not available in typing module until 3.8 + from rez.config import Config + from rez.developer_package import DeveloperPackage + from rez.version import Requirement + from rez.package_repository import PackageRepository + from rez.resolved_context import ResolvedContext + # ------------------------------------------------------------------------------ # package-related classes @@ -36,7 +48,7 @@ def validated_data(self): return data @property - def repository(self): + def repository(self) -> PackageRepository: """The package repository this resource comes from. Returns: @@ -54,11 +66,11 @@ class PackageFamily(PackageRepositoryResourceWrapper): """ keys = schema_keys(package_family_schema) - def __init__(self, resource): + def __init__(self, resource: PackageFamilyResource): _check_class(resource, PackageFamilyResource) super(PackageFamily, self).__init__(resource) - def iter_packages(self): + def iter_packages(self) -> Iterator[Package]: """Iterate over the packages within this family, in no particular order. Returns: @@ -75,14 +87,14 @@ class PackageBaseResourceWrapper(PackageRepositoryResourceWrapper): "requires": late_requires_schema } - def __init__(self, resource, context=None): + def __init__(self, resource: PackageResource | VariantResource, context: ResolvedContext | None = None): super(PackageBaseResourceWrapper, self).__init__(resource) self.context = context # cached results of late-bound funcs self._late_binding_returnvalues = {} - def set_context(self, context): + def set_context(self, context: ResolvedContext): self.context = context def arbitrary_keys(self): @@ -93,7 +105,7 @@ def uri(self): return self.resource.uri @property - def config(self): + def config(self) -> Config: """Returns the config for this package. Defaults to global config if this package did not provide a 'config' @@ -102,7 +114,7 @@ def config(self): return self.resource.config or config @cached_property - def is_local(self): + def is_local(self) -> bool: """Returns True if the package is in the local package repository""" local_repo = package_repository_manager.get_repository( self.config.local_packages_path) @@ -160,8 +172,8 @@ def _wrap_forwarded(self, key, value): else: return value - def _eval_late_binding(self, sourcecode): - g = {} + def _eval_late_binding(self, sourcecode: SourceCode): + g: dict[str, Any] = {} if self.context is None: g["in_context"] = lambda: False @@ -200,7 +212,7 @@ class Package(PackageBaseResourceWrapper): #: funcs, where ``this`` may be a package or variant. is_variant = False - def __init__(self, resource, context=None): + def __init__(self, resource: PackageResource, context=None): _check_class(resource, PackageResource) super(Package, self).__init__(resource, context) @@ -223,7 +235,7 @@ def arbitrary_keys(self): return set(self.data.keys()) - set(self.keys) @cached_property - def qualified_name(self): + def qualified_name(self) -> str: """Get the qualified name of the package. Returns: @@ -232,7 +244,7 @@ def qualified_name(self): o = VersionedObject.construct(self.name, self.version) return str(o) - def as_exact_requirement(self): + def as_exact_requirement(self) -> str: """Get the package, as an exact requirement string. Returns: @@ -242,7 +254,7 @@ def as_exact_requirement(self): return o.as_exact_requirement() @cached_property - def parent(self): + def parent(self) -> PackageFamily | None: """Get the parent package family. Returns: @@ -252,11 +264,11 @@ def parent(self): return PackageFamily(family) if family else None @cached_property - def num_variants(self): + def num_variants(self) -> int: return len(self.data.get("variants", [])) @property - def is_relocatable(self): + def is_relocatable(self) -> bool: """True if the package and its payload is safe to copy. """ if self.relocatable is not None: @@ -276,7 +288,7 @@ def is_relocatable(self): return config.default_relocatable @property - def is_cachable(self): + def is_cachable(self) -> bool: """True if the package and its payload is safe to cache locally. """ if self.cachable is not None: @@ -301,7 +313,7 @@ def is_cachable(self): return self.is_relocatable - def iter_variants(self): + def iter_variants(self) -> Iterator[Variant]: """Iterate over the variants within this package, in index order. Returns: @@ -310,7 +322,7 @@ def iter_variants(self): for variant in self.repository.iter_variants(self.resource): yield Variant(variant, context=self.context, parent=self) - def get_variant(self, index=None): + def get_variant(self, index=None) -> Variant | None: """Get the variant with the associated index. Returns: @@ -319,6 +331,7 @@ def get_variant(self, index=None): for variant in self.iter_variants(): if variant.index == index: return variant + return None class Variant(PackageBaseResourceWrapper): @@ -337,7 +350,7 @@ class Variant(PackageBaseResourceWrapper): #: See :attr:`Package.is_variant`. is_variant = True - def __init__(self, resource, context=None, parent=None): + def __init__(self, resource: VariantResource, context=None, parent=None): _check_class(resource, VariantResource) super(Variant, self).__init__(resource, context) self._parent = parent @@ -353,12 +366,12 @@ def arbitrary_keys(self): return self.parent.arbitrary_keys() @cached_property - def qualified_package_name(self): + def qualified_package_name(self) -> str: o = VersionedObject.construct(self.name, self.version) return str(o) @cached_property - def qualified_name(self): + def qualified_name(self) -> str: """Get the qualified name of the variant. Returns: @@ -368,7 +381,7 @@ def qualified_name(self): return "%s[%s]" % (self.qualified_package_name, idxstr) @cached_property - def parent(self): + def parent(self) -> Package: """Get the parent package. Returns: @@ -386,7 +399,7 @@ def parent(self): return self._parent @property - def variant_requires(self): + def variant_requires(self) -> list[Requirement]: """Get the subset of requirements specific to this variant. Returns: @@ -398,7 +411,7 @@ def variant_requires(self): return self.parent.variants[self.index] or [] @property - def requires(self): + def requires(self) -> list[Requirement]: """Get variant requirements. This is a concatenation of the package requirements and those of this @@ -411,7 +424,8 @@ def requires(self): (self.parent.requires or []) + self.variant_requires ) - def get_requires(self, build_requires=False, private_build_requires=False): + def get_requires(self, build_requires=False, private_build_requires=False + ) -> list[Requirement]: """Get the requirements of the variant. Args: @@ -431,7 +445,7 @@ def get_requires(self, build_requires=False, private_build_requires=False): return requires - def install(self, path, dry_run=False, overrides=None): + def install(self, path, dry_run=False, overrides=None) -> Variant: """Install this variant into another package repository. If the package already exists, this variant will be correctly merged @@ -518,7 +532,7 @@ def _repository_uids(self): # resource acquisition functions # ------------------------------------------------------------------------------ -def iter_package_families(paths=None): +def iter_package_families(paths: list[str] | None = None): """Iterate over package families, in no particular order. Note that multiple package families with the same name can be returned. @@ -538,7 +552,8 @@ def iter_package_families(paths=None): yield PackageFamily(resource) -def iter_packages(name, range_=None, paths=None): +def iter_packages(name: str, range_: VersionRange | str | None = None, + paths: list[str] | None = None) -> Iterator[Package]: """Iterate over `Package` instances, in no particular order. Packages of the same name and version earlier in the search path take @@ -574,7 +589,7 @@ def iter_packages(name, range_=None, paths=None): yield Package(package_resource) -def get_package(name, version, paths=None): +def get_package(name: str, version: Version | str, paths: list[str] | None = None) -> Package | None: """Get a package by searching a list of repositories. Args: @@ -598,7 +613,7 @@ def get_package(name, version, paths=None): return None -def get_package_family_from_repository(name, path): +def get_package_family_from_repository(name: str, path: str): """Get a package family from a repository. Args: @@ -616,7 +631,7 @@ def get_package_family_from_repository(name, path): return PackageFamily(family_resource) -def get_package_from_repository(name, version, path): +def get_package_from_repository(name: str, version, path: str): """Get a package from a repository. Args: @@ -656,7 +671,7 @@ def get_package_from_handle(package_handle): return package -def get_package_from_string(txt, paths=None): +def get_package_from_string(txt: str, paths: list[str] | None = None): """Get a package given a string. Args: @@ -671,12 +686,12 @@ def get_package_from_string(txt, paths=None): return get_package(o.name, o.version, paths=paths) -def get_developer_package(path, format=None): +def get_developer_package(path: str, format: FileFormat | None = None) -> DeveloperPackage: """Create a developer package. Args: path (str): Path to dir containing package definition file. - format (str): Package definition file format, detected if None. + format (FileFormat): Package definition file format, detected if None. Returns: `DeveloperPackage`. @@ -685,7 +700,7 @@ def get_developer_package(path, format=None): return DeveloperPackage.from_path(path, format=format) -def create_package(name, data, package_cls=None): +def create_package(name: str, data, package_cls=None): """Create a package given package data. Args: @@ -700,7 +715,7 @@ def create_package(name, data, package_cls=None): return maker.get_package() -def get_variant(variant_handle, context=None): +def get_variant(variant_handle: ResourceHandle | dict, context=None) -> Variant: """Create a variant given its handle (or serialized dict equivalent) Args: @@ -721,7 +736,7 @@ def get_variant(variant_handle, context=None): return variant -def get_package_from_uri(uri, paths=None): +def get_package_from_uri(uri: str, paths: list[str] | None = None) -> Package | None: """Get a package given its URI. Args: @@ -768,7 +783,7 @@ def _find_in_path(path): return _find_in_path(path) -def get_variant_from_uri(uri, paths=None): +def get_variant_from_uri(uri: str, paths: list[str] | None = None) -> Variant | None: """Get a variant given its URI. Args: @@ -822,7 +837,7 @@ def _find_in_path(path): return _find_in_path(path) -def get_last_release_time(name, paths=None): +def get_last_release_time(name: str, paths: list[str] | None = None) -> int: """Returns the most recent time this package was released. Note that releasing a variant into an already-released package is also @@ -848,7 +863,7 @@ def get_last_release_time(name, paths=None): return max_time -def get_completions(prefix, paths=None, family_only=False): +def get_completions(prefix: str, paths: list[str] | None = None, family_only=False): """Get autocompletion options given a prefix string. Example: @@ -904,7 +919,23 @@ def get_completions(prefix, paths=None, family_only=False): return words -def get_latest_package(name, range_=None, paths=None, error=False): +@overload +def get_latest_package(name: str, *, range_=None, + paths: list[str] | None = None, + error: Literal[True] = True) -> Package: + pass + + +@overload +def get_latest_package(name: str, *, range_=None, + paths: list[str] | None = None, + error: Literal[False] = False) -> Package | None: + pass + + +def get_latest_package(name: str, *, range_=None, + paths: list[str] | None = None, + error: bool = False) -> Package | None: """Get the latest package for a given package name. Args: @@ -928,7 +959,7 @@ def get_latest_package(name, range_=None, paths=None, error=False): return None -def get_latest_package_from_string(txt, paths=None, error=False): +def get_latest_package_from_string(txt: str, paths: list[str] | None = None, error=False): """Get the latest package found within the given request string. Args: @@ -949,7 +980,8 @@ def get_latest_package_from_string(txt, paths=None, error=False): error=error) -def _get_families(name, paths=None): +def _get_families(name: str, paths: list[str] | None = None + ) -> list[tuple[PackageRepository, PackageFamilyResource]]: entries = [] for path in (paths or config.packages_path): repo = package_repository_manager.get_repository(path) diff --git a/src/rez/pip.py b/src/rez/pip.py index 1703bb567..bcd072cb8 100644 --- a/src/rez/pip.py +++ b/src/rez/pip.py @@ -43,7 +43,7 @@ class InstallMode(Enum): min_deps = 1 -def run_pip_command(command_args, pip_version=None, python_version=None): +def run_pip_command(command_args, pip_version=None, python_version=None) -> Popen: """Run a pip command. Args: command_args (list of str): Args to pip. diff --git a/src/rez/plugin_managers.py b/src/rez/plugin_managers.py index 64b1d8f09..dc0c7a933 100644 --- a/src/rez/plugin_managers.py +++ b/src/rez/plugin_managers.py @@ -5,6 +5,8 @@ """ Manages loading of all types of Rez plugins. """ +from __future__ import annotations + from rez.config import config, expand_system_vars, _load_config_from_filepaths from rez.utils.formatting import columnise from rez.utils.schema import dict_to_schema @@ -12,9 +14,13 @@ from rez.utils.logging_ import print_debug, print_warning from rez.exceptions import RezPluginError from zipimport import zipimporter +from typing import overload, Any, TypeVar import pkgutil import os.path import sys +import types + +T = TypeVar("T") # modified from pkgutil standard library: @@ -85,30 +91,30 @@ class RezPluginType(object): 'type_name' must correspond with one of the source directories found under the 'plugins' directory. """ - type_name = None + type_name: str def __init__(self): if self.type_name is None: raise TypeError("Subclasses of RezPluginType must provide a " "'type_name' attribute") self.pretty_type_name = self.type_name.replace('_', ' ') - self.plugin_classes = {} - self.failed_plugins = {} - self.plugin_modules = {} + self.plugin_classes: dict[str, type] = {} + self.failed_plugins: dict[str, str] = {} + self.plugin_modules: dict[str, types.ModuleType] = {} self.config_data = {} self.load_plugins() def __repr__(self): return '%s(%s)' % (self.__class__.__name__, self.plugin_classes.keys()) - def register_plugin(self, plugin_name, plugin_class, plugin_module): + def register_plugin(self, plugin_name: str, plugin_class: type, plugin_module: types.ModuleType) -> None: # TODO: check plugin_class to ensure it is a sub-class of expected base-class? # TODO: perhaps have a Plugin base class. This introduces multiple # inheritance in Shell class though :/ self.plugin_classes[plugin_name] = plugin_class self.plugin_modules[plugin_name] = plugin_module - def load_plugins(self): + def load_plugins(self) -> None: import pkgutil from importlib import import_module type_module_name = 'rezplugins.' + self.type_name @@ -205,7 +211,7 @@ def load_plugins(self): data, _ = _load_config_from_filepaths([os.path.join(path, "rezconfig")]) deep_update(self.config_data, data) - def get_plugin_class(self, plugin_name): + def get_plugin_class(self, plugin_name: str) -> type: """Returns the class registered under the given plugin name.""" try: return self.plugin_classes[plugin_name] @@ -213,7 +219,7 @@ def get_plugin_class(self, plugin_name): raise RezPluginError("Unrecognised %s plugin: '%s'" % (self.pretty_type_name, plugin_name)) - def get_plugin_module(self, plugin_name): + def get_plugin_module(self, plugin_name: str) -> types.ModuleType: """Returns the module containing the plugin of the given name.""" try: return self.plugin_modules[plugin_name] @@ -235,7 +241,7 @@ def config_schema(self): deep_update(d, d_) return dict_to_schema(d, required=True, modifier=expand_system_vars) - def create_instance(self, plugin, **instance_kwargs): + def create_instance(self, plugin: str, **instance_kwargs) -> Any: """Create and return an instance of the given plugin.""" return self.get_plugin_class(plugin)(**instance_kwargs) @@ -294,7 +300,7 @@ def register_plugin(): 'rezplugins' is always found first. """ def __init__(self): - self._plugin_types = {} + self._plugin_types: dict[str, LazySingleton[RezPluginType]] = {} @cached_property def rezplugins_module_paths(self): @@ -329,14 +335,14 @@ def rezplugins_module_paths(self): # -- plugin types - def _get_plugin_type(self, plugin_type): + def _get_plugin_type(self, plugin_type: str) -> RezPluginType: try: return self._plugin_types[plugin_type]() except KeyError: raise RezPluginError("Unrecognised plugin type: '%s'" % plugin_type) - def register_plugin_type(self, type_class): + def register_plugin_type(self, type_class: type[RezPluginType]) -> None: if not issubclass(type_class, RezPluginType): raise TypeError("'type_class' must be a RezPluginType sub class") if type_class.type_name is None: @@ -344,38 +350,50 @@ def register_plugin_type(self, type_class): "'type_name' attribute") self._plugin_types[type_class.type_name] = LazySingleton(type_class) - def get_plugin_types(self): + def get_plugin_types(self) -> list[str]: """Return a list of the registered plugin types.""" - return self._plugin_types.keys() + return list(self._plugin_types.keys()) # -- plugins - def get_plugins(self, plugin_type): + def get_plugins(self, plugin_type: str) -> list[str]: """Return a list of the registered names available for the given plugin type.""" - return self._get_plugin_type(plugin_type).plugin_classes.keys() + return list(self._get_plugin_type(plugin_type).plugin_classes.keys()) + + @overload + def get_plugin_class(self, plugin_type: str, plugin_name: str) -> type: + pass + + @overload + def get_plugin_class(self, plugin_type: str, plugin_name: str, expected_type: type[T]) -> type[T]: + pass - def get_plugin_class(self, plugin_type, plugin_name): + def get_plugin_class(self, plugin_type: str, plugin_name: str, expected_type: type | None = None) -> type: """Return the class registered under the given plugin name.""" plugin = self._get_plugin_type(plugin_type) - return plugin.get_plugin_class(plugin_name) + cls = plugin.get_plugin_class(plugin_name) + if expected_type is not None and not isinstance(cls, type) and issubclass(cls, expected_type): + raise RezPluginError("%s: Plugin class for %s was not the expected type: %s != %s" + % (plugin.pretty_type_name, plugin_name, cls, expected_type)) + return cls - def get_plugin_module(self, plugin_type, plugin_name): + def get_plugin_module(self, plugin_type: str, plugin_name: str) -> types.ModuleType: """Return the module defining the class registered under the given plugin name.""" plugin = self._get_plugin_type(plugin_type) return plugin.get_plugin_module(plugin_name) - def get_plugin_config_data(self, plugin_type): + def get_plugin_config_data(self, plugin_type: str): """Return the merged configuration data for the plugin type.""" plugin = self._get_plugin_type(plugin_type) return plugin.config_data - def get_plugin_config_schema(self, plugin_type): + def get_plugin_config_schema(self, plugin_type: str): plugin = self._get_plugin_type(plugin_type) return plugin.config_schema - def get_failed_plugins(self, plugin_type): + def get_failed_plugins(self, plugin_type: str) -> list[tuple[str, str]]: """Return a list of plugins for the given type that failed to load. Returns: @@ -383,14 +401,14 @@ def get_failed_plugins(self, plugin_type): name (str): Name of the plugin. reason (str): Error message. """ - return self._get_plugin_type(plugin_type).failed_plugins.items() + return list(self._get_plugin_type(plugin_type).failed_plugins.items()) - def create_instance(self, plugin_type, plugin_name, **instance_kwargs): + def create_instance(self, plugin_type: str, plugin_name, **instance_kwargs) -> Any: """Create and return an instance of the given plugin.""" plugin_type = self._get_plugin_type(plugin_type) return plugin_type.create_instance(plugin_name, **instance_kwargs) - def get_summary_string(self): + def get_summary_string(self) -> str: """Get a formatted string summarising the plugins that were loaded.""" rows = [["PLUGIN TYPE", "NAME", "DESCRIPTION", "STATUS"], ["-----------", "----", "-----------", "------"]] diff --git a/src/rez/release_vcs.py b/src/rez/release_vcs.py index dc232c402..7d7754eda 100644 --- a/src/rez/release_vcs.py +++ b/src/rez/release_vcs.py @@ -25,12 +25,12 @@ def create_release_vcs(path, vcs_name=None): if vcs_name: if vcs_name not in vcs_types: raise ReleaseVCSError("Unknown version control system: %r" % vcs_name) - cls = plugin_manager.get_plugin_class('release_vcs', vcs_name) + cls = plugin_manager.get_plugin_class('release_vcs', vcs_name, expected_type=ReleaseVCS) return cls(path) classes_by_level = {} for vcs_name in vcs_types: - cls = plugin_manager.get_plugin_class('release_vcs', vcs_name) + cls = plugin_manager.get_plugin_class('release_vcs', vcs_name, expected_type=ReleaseVCS) result = cls.find_vcs_root(path) if not result: continue @@ -70,7 +70,7 @@ def create_release_vcs(path, vcs_name=None): class ReleaseVCS(object): """A version control system (VCS) used to release Rez packages. """ - def __init__(self, pkg_root, vcs_root=None): + def __init__(self, pkg_root: str, vcs_root=None): if vcs_root is None: result = self.find_vcs_root(pkg_root) if not result: @@ -92,7 +92,7 @@ def name(cls): raise NotImplementedError @classmethod - def find_executable(cls, name): + def find_executable(cls, name: str): exe = which(name) if not exe: raise ReleaseVCSError("Couldn't find executable '%s' for VCS '%s'" @@ -100,7 +100,7 @@ def find_executable(cls, name): return exe @classmethod - def is_valid_root(cls, path): + def is_valid_root(cls, path: str): """Return True if the given path is a valid root directory for this version control system. @@ -118,7 +118,7 @@ def search_parents_for_root(cls): raise NotImplementedError @classmethod - def find_vcs_root(cls, path): + def find_vcs_root(cls, path: str): """Try to find a version control root directory of this type for the given path. @@ -141,7 +141,7 @@ def validate_repostate(self): """Ensure that the VCS working copy is up-to-date.""" raise NotImplementedError - def get_current_revision(self): + def get_current_revision(self) -> object: """Get the current revision, this can be any type (str, dict etc) appropriate to your VCS implementation. @@ -152,7 +152,7 @@ def get_current_revision(self): """ raise NotImplementedError - def get_changelog(self, previous_revision=None, max_revisions=None): + def get_changelog(self, previous_revision=None, max_revisions=None) -> str: """Get the changelog text since the given revision. If previous_revision is not an ancestor (for example, the last release @@ -169,7 +169,7 @@ def get_changelog(self, previous_revision=None, max_revisions=None): """ raise NotImplementedError - def tag_exists(self, tag_name): + def tag_exists(self, tag_name: str) -> bool: """Test if a tag exists in the repo. Args: @@ -180,7 +180,7 @@ def tag_exists(self, tag_name): """ raise NotImplementedError - def create_release_tag(self, tag_name, message=None): + def create_release_tag(self, tag_name: str, message=None): """Create a tag in the repo. Create a tag in the repository representing the release of the @@ -193,7 +193,7 @@ def create_release_tag(self, tag_name, message=None): raise NotImplementedError @classmethod - def export(cls, revision, path): + def export(cls, revision, path: str): """Export the repository to the given path at the given revision. Note: diff --git a/src/rez/resolved_context.py b/src/rez/resolved_context.py index 7e1586b41..39e557901 100644 --- a/src/rez/resolved_context.py +++ b/src/rez/resolved_context.py @@ -2,6 +2,8 @@ # Copyright Contributors to the Rez Project +from __future__ import annotations + from rez import __version__, module_root_path from rez.package_repository import package_repository_manager from rez.solver import SolverCallbackReturn @@ -19,13 +21,13 @@ from rez.utils.memcached import pool_memcached_connections from rez.utils.logging_ import print_error, print_warning from rez.utils.which import which -from rez.rex import RexExecutor, Python, OutputStyle, literal +from rez.rex import ActionInterpreter, RexExecutor, Python, OutputStyle, literal from rez.rex_bindings import VersionBinding, VariantBinding, \ VariantsBinding, RequirementsBinding, EphemeralsBinding, intersects from rez import package_order -from rez.packages import get_variant, iter_packages +from rez.packages import get_variant, iter_packages, Package, Variant from rez.package_filter import PackageFilterList -from rez.package_order import PackageOrderList +from rez.package_order import PackageOrder, PackageOrderList from rez.package_cache import PackageCache from rez.shells import create_shell from rez.exceptions import ResolvedContextError, PackageCommandError, \ @@ -42,6 +44,7 @@ from contextlib import contextmanager from functools import wraps from enum import Enum +from typing import Any, Callable, Iterable, TypeVar, TYPE_CHECKING import getpass import json import socket @@ -51,6 +54,11 @@ import os import os.path +if TYPE_CHECKING: + from rez.solver import SolverState, SupportsWrite + +CallableT = TypeVar("CallableT", bound=Callable) + class RezToolsVisibility(Enum): """Determines if/how rez cli tools are added back to PATH within a @@ -123,6 +131,17 @@ def get_lock_request(name, version, patch_lock, weak=True): return PackageRequest(s) +def _on_success(fn: CallableT) -> CallableT: + @wraps(fn) + def _check(self, *nargs, **kwargs): + if self.status_ == ResolverStatus.solved: + return fn(self, *nargs, **kwargs) + else: + raise ResolvedContextError( + "Cannot perform operation in a failed context") + return _check + + class ResolvedContext(object): """A class that resolves, stores and spawns Rez environments. @@ -136,7 +155,7 @@ class ResolvedContext(object): """ serialize_version = (4, 7) tmpdir_manager = TempDirs(config.context_tmpdir, prefix="rez_context_") - context_tracking_payload = None + context_tracking_payload: dict[str, Any] | None = None context_tracking_lock = threading.Lock() package_cache_present = True local = threading.local() @@ -162,16 +181,29 @@ def __call__(self, state): return self.callback(state) return SolverCallbackReturn.keep_going, '' - def __init__(self, package_requests, verbosity=0, timestamp=None, - building=False, caching=None, package_paths=None, - package_filter=None, package_orderers=None, max_fails=-1, - add_implicit_packages=True, time_limit=-1, callback=None, - package_load_callback=None, buf=None, suppress_passive=False, - print_stats=False, package_caching=None, package_cache_async=None): + def __init__(self, + package_requests: Iterable[str | Requirement], + verbosity=0, + timestamp: float | None = None, + building=False, + caching: bool | None = None, + package_paths: list[str] | None = None, + package_filter: PackageFilterList | None = None, + package_orderers: list[PackageOrder] | None = None, + max_fails=-1, + add_implicit_packages=True, + time_limit=-1, + callback: Callable[[SolverState], tuple[SolverCallbackReturn, str]] | None = None, + package_load_callback: Callable[[Package], Any] | None = None, + buf: SupportsWrite | None = None, + suppress_passive=False, + print_stats=False, + package_caching=None, + package_cache_async=None): """Perform a package resolve, and store the result. Args: - package_requests (list[typing.Union[str, PackageRequest]]): request + package_requests (list[typing.Union[str, Requirement]]): request verbosity (int): Verbosity level. One of [0,1,2]. timestamp (float): Ignore packages released after this epoch time. Packages released at exactly this time will not be ignored. @@ -214,13 +246,15 @@ def __init__(self, package_requests, verbosity=0, timestamp=None, self.requested_timestamp = timestamp self.timestamp = self.requested_timestamp or int(time.time()) self.building = building - self.implicit_packages = [] + self.implicit_packages: list[Requirement] = [] self.caching = config.resolve_caching if caching is None else caching self.verbosity = verbosity - self._package_requests = [] + self._package_requests: list[Requirement] = [] for req in package_requests: if isinstance(req, str): + # FIXME: Requirement seems like it would work fine here. the only difference + # appears to be that PackageRequest does some additional validation req = PackageRequest(req) self._package_requests.append(req) @@ -284,7 +318,7 @@ def __init__(self, package_requests, verbosity=0, timestamp=None, # the pre-resolve bindings. We store these because @late package.py # functions need them, and we cache them to avoid cost - self.pre_resolve_bindings = None + self.pre_resolve_bindings: dict[str, Any] | None = None # suite information self.parent_suite_path = None @@ -355,12 +389,12 @@ def __str__(self): self.status.name, req_str) @property - def success(self): + def success(self) -> bool: """True if the context has been solved, False otherwise.""" return (self.status_ == ResolverStatus.solved) @property - def status(self): + def status(self) -> ResolverStatus: """Return the current status of the context. Returns: @@ -368,7 +402,7 @@ def status(self): """ return self.status_ - def requested_packages(self, include_implicit=False): + def requested_packages(self, include_implicit=False) -> list[Requirement]: """Get packages in the request. Args: @@ -376,7 +410,7 @@ def requested_packages(self, include_implicit=False): to the result. Returns: - list[PackageRequest]: + list[Requirement]: """ if include_implicit: return self._package_requests + self.implicit_packages @@ -384,7 +418,7 @@ def requested_packages(self, include_implicit=False): return self._package_requests @property - def resolved_packages(self): + def resolved_packages(self) -> list[Variant] | None: """Get packages in the resolve. Returns: @@ -393,7 +427,7 @@ def resolved_packages(self): return self._resolved_packages @property - def resolved_ephemerals(self): + def resolved_ephemerals(self) -> list[Requirement] | None: """Get non-conflict ephemerals in the resolve. Returns: @@ -424,7 +458,7 @@ def __eq__(self, other): ) def __hash__(self): - list_ = [] + list_: list[Any] = [] req = self.requested_packages(True) list_.append(tuple(req)) res = self.resolved_packages @@ -441,7 +475,7 @@ def has_graph(self): """Return True if the resolve has a graph.""" return bool((self.graph_ is not None) or self.graph_string) - def get_resolved_package(self, name): + def get_resolved_package(self, name: str) -> Variant | None: """Returns a `Variant` object or None if the package is not in the resolve. """ @@ -610,8 +644,8 @@ def get_patched_request(self, package_requests=None, if len(variant.version) >= rank: version = variant.version.trim(rank - 1) version = next(version) - req = "~%s<%s" % (variant.name, str(version)) - rank_limiters.append(req) + req_str = "~%s<%s" % (variant.name, str(version)) + rank_limiters.append(req_str) request += rank_limiters return request @@ -647,7 +681,7 @@ def graph(self, as_dot=False): return write_dot(self.graph_) - def save(self, path): + def save(self, path: str): """Save the resolved context to file.""" with self._detect_bundle(path): with open(path, 'w') as f: @@ -706,7 +740,7 @@ def read_from_buffer(cls, buf, identifier_str=None): except Exception as e: cls._load_error(e, identifier_str) - def get_resolve_diff(self, other): + def get_resolve_diff(self, other: ResolvedContext): """Get the difference between the resolve in this context and another. The difference is described from the point of view of the current context @@ -745,7 +779,8 @@ def get_resolve_diff(self, other): raise ResolvedContextError("Cannot diff resolves, package search " "paths differ:\n%s" % '\n'.join(diff)) - d = {} + # FIXME: make this a TypedDict + d: dict[str, Any] = {} self_pkgs_ = set(x.parent for x in self._resolved_packages) other_pkgs_ = set(x.parent for x in other._resolved_packages) self_pkgs = self_pkgs_ - other_pkgs_ @@ -831,7 +866,7 @@ def _rt(t): if verbosity: _pr("search paths:", heading) - rows = [] + rows: list[tuple[str, str]] = [] colors = [] for path in self.package_paths: if package_repository_manager.are_same(path, config.local_packages_path): @@ -889,7 +924,7 @@ def _rt(t): return _pr("resolved packages:", heading) - rows = [] + rows3: list[tuple[str, str, str]] = [] colors = [] resolved_packages = self.resolved_packages or [] @@ -929,18 +964,18 @@ def _rt(t): t.append('local') col = local - t = '(%s)' % ', '.join(t) if t else '' - rows.append((pkg.qualified_package_name, location, t)) + t_str = '(%s)' % ', '.join(t) if t else '' + rows3.append((pkg.qualified_package_name, location, t_str)) colors.append(col) # add ephemerals to end of resolved packages list ephemerals = self.resolved_ephemerals or [] ephemerals = sorted(ephemerals, key=lambda x: x.name) for req in ephemerals: - rows.append((str(req), '', "(ephemeral)")) + rows3.append((str(req), '', "(ephemeral)")) colors.append(ephemeral_color) - for col, line in zip(colors, columnise(rows)): + for col, line in zip(colors, columnise(rows3)): _pr(line, col) if verbosity: @@ -1007,7 +1042,7 @@ def print_resolve_diff(self, other, heading=None): b = os.path.basename(other.load_path) heading = (a, b) if isinstance(heading, tuple): - rows.append(list(heading) + [""]) + rows.append(heading + ("",)) rows.append(('-' * len(heading[0]), '-' * len(heading[1]), "")) newer_packages = d.get("newer_packages", {}) @@ -1043,16 +1078,6 @@ def print_resolve_diff(self, other, heading=None): print('\n'.join(columnise(rows))) - def _on_success(fn): - @wraps(fn) - def _check(self, *nargs, **kwargs): - if self.status_ == ResolverStatus.solved: - return fn(self, *nargs, **kwargs) - else: - raise ResolvedContextError( - "Cannot perform operation in a failed context") - return _check - @_on_success def get_dependency_graph(self, as_dot=False): """Generate the dependency graph. @@ -1109,7 +1134,7 @@ def validate(self): raise ResolvedContextError("%s: %s" % (e.__class__.__name__, str(e))) @_on_success - def get_environ(self, parent_environ=None): + def get_environ(self, parent_environ=None) -> dict[str, str]: """Get the environ dict resulting from interpreting this context. Args: @@ -1126,7 +1151,7 @@ def get_environ(self, parent_environ=None): return executor.get_output() @_on_success - def get_key(self, key, request_only=False): + def get_key(self, key: str, request_only=False) -> dict[str, tuple[Variant, Any]]: """Get a data key value for each resolved package. Args: @@ -1150,7 +1175,7 @@ def get_key(self, key, request_only=False): return values @_on_success - def get_tools(self, request_only=False): + def get_tools(self, request_only=False) -> dict[str, tuple[Variant, list[str]]]: """Returns the commandline tools available in the context. Args: @@ -1163,7 +1188,7 @@ def get_tools(self, request_only=False): return self.get_key("tools", request_only=request_only) @_on_success - def get_tool_variants(self, tool_name): + def get_tool_variants(self, tool_name: str) -> set[Variant]: """Get the variant(s) that provide the named tool. If there are more than one variants, the tool is in conflict, and Rez @@ -1184,7 +1209,7 @@ def get_tool_variants(self, tool_name): return variants @_on_success - def get_conflicting_tools(self, request_only=False): + def get_conflicting_tools(self, request_only=False) -> dict[str, set[Variant]]: """Returns tools of the same name provided by more than one package. Args: @@ -1206,7 +1231,7 @@ def get_conflicting_tools(self, request_only=False): return conflicts @_on_success - def get_shell_code(self, shell=None, parent_environ=None, style=OutputStyle.file): + def get_shell_code(self, shell: str | None = None, parent_environ=None, style=OutputStyle.file): """Get the shell code resulting from intepreting this context. Args: @@ -1317,7 +1342,7 @@ def execute_command(self, args, parent_environ=None, **Popen_args): @_on_success def execute_rex_code(self, code, filename=None, shell=None, - parent_environ=None, **Popen_args): + parent_environ: dict[str, str] | None = None, **Popen_args): """Run some rex code in the context. Note: @@ -1346,9 +1371,13 @@ def _actions_callback(executor): **Popen_args) @_on_success - def execute_shell(self, shell=None, parent_environ=None, rcfile=None, - norc=False, stdin=False, command=None, quiet=False, - block=None, actions_callback=None, post_actions_callback=None, + def execute_shell(self, + shell: str | None = None, + parent_environ: dict[str, str] | None = None, + rcfile=None, norc=False, stdin=False, command=None, quiet=False, + block=None, + actions_callback: Callable[[RexExecutor], Any] | None = None, + post_actions_callback: Callable[[RexExecutor], Any] | None = None, context_filepath=None, start_new_session=False, detached=False, pre_command=None, **Popen_args): """Spawn a possibly-interactive shell. @@ -1476,7 +1505,7 @@ def execute_shell(self, shell=None, parent_environ=None, rcfile=None, return p @_on_success - def get_resolve_as_exact_requests(self): + def get_resolve_as_exact_requests(self) -> list[PackageRequest]: """Convert to a package request list of exact resolved package versions. >>> r = ResolvedContext(['foo'] @@ -1487,10 +1516,10 @@ def get_resolve_as_exact_requests(self): List of `PackageRequest`: Context as a list of exact version requests. """ - def to_req(variant): + def to_req(variant: Variant) -> PackageRequest: return PackageRequest(variant.parent.as_exact_requirement()) - return map(to_req, self.resolved_packages) + return [to_req(r) for r in self.resolved_packages] def to_dict(self, fields=None): """Convert context to dict containing only builtin types. @@ -1503,7 +1532,7 @@ def to_dict(self, fields=None): Returns: dict: Dictified context. """ - data = {} + data: dict[str, Any] = {} def _add(field): return (fields is None or field in fields) @@ -1730,7 +1759,7 @@ def _print_version(value): return r - def _execute_bundle_post_actions_callback(self, executor): + def _execute_bundle_post_actions_callback(self, executor: RexExecutor): """ In bundles, you can drop a 'post_commands.py' file (rex) alongside the 'bundle.yaml' file, and it will be sourced after all package commands. @@ -1941,7 +1970,8 @@ def _set_parent_suite(self, suite_path, context_name): self.parent_suite_path = suite_path self.suite_context_name = context_name - def _create_executor(self, interpreter, parent_environ): + def _create_executor(self, interpreter: ActionInterpreter, + parent_environ: dict[str, str] | None) -> RexExecutor: parent_vars = True if config.all_parent_variables \ else config.parent_variables @@ -1962,7 +1992,7 @@ def _get_pre_resolve_bindings(self): return self.pre_resolve_bindings @pool_memcached_connections - def _execute(self, executor): + def _execute(self, executor: RexExecutor): """Bind various info to the execution context """ def normalized(path): @@ -2148,7 +2178,7 @@ def normalized(path): elif mode == RezToolsVisibility.prepend: executor.prepend_rez_path() - def _append_suite_paths(self, executor): + def _append_suite_paths(self, executor: RexExecutor): from rez.suite import Suite mode = SuiteVisibility[config.suite_visibility] diff --git a/src/rez/resolver.py b/src/rez/resolver.py index 921871ce6..ad0a75972 100644 --- a/src/rez/resolver.py +++ b/src/rez/resolver.py @@ -2,9 +2,11 @@ # Copyright Contributors to the Rez Project -from rez.solver import Solver, SolverStatus +from __future__ import annotations + +from rez.solver import Solver, SolverCallbackReturn, SolverState, SolverStatus from rez.package_repository import package_repository_manager -from rez.packages import get_variant, get_last_release_time +from rez.packages import get_variant, get_last_release_time, Variant from rez.package_filter import PackageFilterList, TimestampRule from rez.utils.memcached import memcached_client, pool_memcached_connections from rez.utils.logging_ import log_duration @@ -13,6 +15,11 @@ from contextlib import contextmanager from enum import Enum from hashlib import sha1 +from typing import Callable, TYPE_CHECKING + +if TYPE_CHECKING: + from rez.package_order import PackageOrder + from rez.resolved_context import ResolvedContext class ResolverStatus(Enum): @@ -35,10 +42,21 @@ class Resolver(object): The Resolver uses a combination of Solver(s) and cache(s) to resolve a package request as quickly as possible. """ - def __init__(self, context, package_requests, package_paths, package_filter=None, - package_orderers=None, timestamp=0, callback=None, building=False, - verbosity=False, buf=None, package_load_callback=None, caching=True, - suppress_passive=False, print_stats=False): + def __init__(self, + context: ResolvedContext, + package_requests: list[Requirement], + package_paths: list[str], + package_filter: PackageFilterList | None = None, + package_orderers: list[PackageOrder] | None = None, + timestamp=0, + callback: Callable[[SolverState], tuple[SolverCallbackReturn, str]] | None = None, + building=False, + verbosity=False, + buf=None, + package_load_callback=None, + caching=True, + suppress_passive=False, + print_stats=False): """Create a Resolver. Args: @@ -96,8 +114,8 @@ def __init__(self, context, package_requests, package_paths, package_filter=None self.package_filter = package_filter self.status_ = ResolverStatus.pending - self.resolved_packages_ = None - self.resolved_ephemerals_ = None + self.resolved_packages_: list[Variant] | None = None + self.resolved_ephemerals_: list[Requirement] | None = None self.failure_description = None self.graph_ = None self.from_cache = False @@ -109,7 +127,7 @@ def __init__(self, context, package_requests, package_paths, package_filter=None self._print = config.debug_printer("resolve_memcache") @pool_memcached_connections - def solve(self): + def solve(self) -> None: """Perform the solve. """ with log_duration(self._print, "memcache get (resolve) took %s"): @@ -137,17 +155,17 @@ def status(self): return self.status_ @property - def resolved_packages(self): + def resolved_packages(self) -> list[Variant] | None: """Get the list of resolved packages. Returns: - List of `PackageVariant` objects, or None if the resolve has not + List of `Variant` objects, or None if the resolve has not completed. """ return self.resolved_packages_ @property - def resolved_ephemerals(self): + def resolved_ephemerals(self) -> list[Requirement] | None: """Get the list of resolved ewphemerals. Returns: @@ -167,7 +185,7 @@ def graph(self): """ return self.graph_ - def _get_variant(self, variant_handle): + def _get_variant(self, variant_handle) -> Variant: return get_variant(variant_handle, context=self.context) def _get_cached_solve(self): @@ -345,6 +363,8 @@ def _set_cached_solve(self, solver_dict): release_times_dict = {} variant_states_dict = {} + assert self.resolved_packages_ is not None, \ + "self.resolved_packages_ is set in _set_result when status is 'solved'" for variant in self.resolved_packages_: time_ = get_last_release_time(variant.name, self.package_paths) @@ -391,7 +411,7 @@ def _memcache_key(self, timestamped=False): return str(tuple(t)) - def _solve(self): + def _solve(self) -> Solver: solver = Solver(package_requests=self.package_requests, package_paths=self.package_paths, context=self.context, diff --git a/src/rez/rex.py b/src/rez/rex.py index 1606a12d1..c7dcad6a5 100644 --- a/src/rez/rex.py +++ b/src/rez/rex.py @@ -2,6 +2,8 @@ # Copyright Contributors to the Rez Project +from __future__ import annotations + import os import sys import re @@ -11,6 +13,7 @@ from contextlib import contextmanager from string import Formatter from collections.abc import MutableMapping +from typing import Iterable from rez.system import system from rez.config import config @@ -30,6 +33,7 @@ #=============================================================================== class Action(object): + name: str _registry = [] def __init__(self, *args): @@ -43,11 +47,11 @@ def __eq__(self, other): return (self.name == other.name) and (self.args == other.args) @classmethod - def register_command_type(cls, name, klass): + def register_command_type(cls, name, klass) -> None: cls._registry.append((name, klass)) @classmethod - def register(cls): + def register(cls) -> None: cls.register_command_type(cls.name, cls) @classmethod @@ -57,13 +61,14 @@ def get_command_types(cls): class EnvAction(Action): @property - def key(self): + def key(self) -> str: return self.args[0] @property - def value(self): + def value(self) -> str | None: if len(self.args) == 2: return self.args[1] + return None class Unsetenv(EnvAction): @@ -173,8 +178,8 @@ class ActionManager(object): """Handles the execution book-keeping. Tracks env variable values, and triggers the callbacks of the `ActionInterpreter`. """ - def __init__(self, interpreter, parent_environ=None, parent_variables=None, - formatter=None, verbose=False, env_sep_map=None): + def __init__(self, interpreter: ActionInterpreter, parent_environ: dict[str, str] | None = None, + parent_variables: Iterable[str] | None = None, formatter=None, verbose=False, env_sep_map=None): ''' interpreter: string or `ActionInterpreter` the interpreter to use when executing rex actions @@ -628,7 +633,7 @@ def __init__(self, target_environ=None, passive=False): are skipped. ''' self.passive = passive - self.manager = None + self.manager: ActionManager | None = None if (target_environ is None) or (target_environ is os.environ): self.target_environ = os.environ self.update_session = True @@ -636,7 +641,7 @@ def __init__(self, target_environ=None, passive=False): self.target_environ = target_environ self.update_session = False - def set_manager(self, manager): + def set_manager(self, manager: ActionManager): self.manager = manager def apply_environ(self): @@ -1209,7 +1214,8 @@ class RexExecutor(object): ex.env.FOO_SET = 1 ex.alias('foo','foo -l') """ - def __init__(self, interpreter=None, globals_map=None, parent_environ=None, + def __init__(self, interpreter: ActionInterpreter | None = None, + globals_map=None, parent_environ: dict[str, str] | None = None, parent_variables=None, shebang=True, add_default_namespaces=True): """ interpreter: `ActionInterpreter` or None diff --git a/src/rez/shells.py b/src/rez/shells.py index 6dfe6d676..03d27a057 100644 --- a/src/rez/shells.py +++ b/src/rez/shells.py @@ -5,6 +5,9 @@ """ Pluggable API for creating subshells using different programs, such as bash. """ + +from __future__ import annotations + from rez.rex import RexExecutor, ActionInterpreter, OutputStyle from rez.util import shlex_join, is_non_string_iterable from rez.utils.which import which @@ -17,9 +20,13 @@ import os import os.path from shlex import quote +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import subprocess -def get_shell_types(): +def get_shell_types() -> list[str]: """Returns the available shell types: bash, tcsh etc. Returns: @@ -29,7 +36,7 @@ def get_shell_types(): return list(plugin_manager.get_plugins('shell')) -def get_shell_class(shell=None): +def get_shell_class(shell: str | None = None) -> type[Shell]: """Get the plugin class associated with the given or current shell. Returns: @@ -40,12 +47,12 @@ def get_shell_class(shell=None): if not shell: from rez.system import system shell = system.shell - + assert shell is not None from rez.plugin_managers import plugin_manager - return plugin_manager.get_plugin_class("shell", shell) + return plugin_manager.get_plugin_class("shell", shell, Shell) -def create_shell(shell=None, **kwargs): +def create_shell(shell: str | None = None, **kwargs) -> Shell: """Returns a Shell of the given or current type. Returns: @@ -67,29 +74,29 @@ class Shell(ActionInterpreter): schema_dict = {"prompt": str} @classmethod - def name(cls): + def name(cls) -> str: """Plugin name. """ raise NotImplementedError @classmethod - def executable_name(cls): + def executable_name(cls) -> str: """Name of executable to create shell instance. """ return cls.name() @classmethod - def executable_filepath(cls): + def executable_filepath(cls) -> str: """Get full filepath to executable, or raise if not found. """ return cls.find_executable(cls.executable_name()) @property - def executable(self): + def executable(self) -> str: return self.__class__.executable_filepath() @classmethod - def is_available(cls): + def is_available(cls) -> bool: """Determine if the shell is available to instantiate. Returns: @@ -101,7 +108,7 @@ def is_available(cls): return False @classmethod - def file_extension(cls): + def file_extension(cls) -> str: """Get the file extension associated with the shell. Returns: @@ -132,7 +139,7 @@ def __init__(self): def _addline(self, line): self._lines.append(line) - def get_output(self, style=OutputStyle.file): + def get_output(self, style=OutputStyle.file) -> str: if style == OutputStyle.file: script = '\n'.join(self._lines) + '\n' else: # eval style @@ -198,7 +205,7 @@ def find_executable(cls, name, check_syspaths=False): def spawn_shell(self, context_file, tmpdir, rcfile=None, norc=False, stdin=False, command=None, env=None, quiet=False, pre_command=None, add_rez=True, - package_commands_sourced_first=None, **Popen_args): + package_commands_sourced_first=None, **Popen_args) -> subprocess.Popen: """Spawn a possibly interactive subshell. Args: @@ -233,7 +240,7 @@ def spawn_shell(self, context_file, tmpdir, rcfile=None, norc=False, raise NotImplementedError @classmethod - def convert_tokens(cls, value): + def convert_tokens(cls, value) -> str: """ Converts any token like ${VAR} and $VAR to shell specific form. Uses the ENV_VAR_REGEX to correctly parse tokens. @@ -250,7 +257,7 @@ def convert_tokens(cls, value): ) @classmethod - def get_key_token(cls, key): + def get_key_token(cls, key) -> str: """ Encodes the environment variable into the shell specific form. Shells might implement multiple forms, but the most common/safest @@ -265,7 +272,7 @@ def get_key_token(cls, key): return cls.get_all_key_tokens(key)[0] @classmethod - def get_all_key_tokens(cls, key): + def get_all_key_tokens(cls, key) -> list[str]: """ Encodes the environment variable into the shell specific forms. Shells might implement multiple forms, but the most common/safest @@ -280,7 +287,7 @@ def get_all_key_tokens(cls, key): raise NotImplementedError @classmethod - def line_terminator(cls): + def line_terminator(cls) -> str: """ Returns: str: default line terminator @@ -288,7 +295,7 @@ def line_terminator(cls): raise NotImplementedError @classmethod - def join(cls, command): + def join(cls, command) -> str: """ Note: Default to unix sh/bash- friendly behaviour. @@ -321,14 +328,14 @@ class UnixShell(Shell): r""" A base class for common \*nix shells, such as bash and tcsh. """ - rcfile_arg = None - norc_arg = None - histfile = None - histvar = None + rcfile_arg: str = None + norc_arg: str = None + histfile: str = None + histvar: str = None command_arg = '-c' stdin_arg = '-s' last_command_status = '$?' - syspaths = None + syspaths: list[str] = None # # startup rules @@ -511,21 +518,21 @@ def _create_ex(): % (cmd_str, str(e))) return p - def resetenv(self, key, value, friends=None): + def resetenv(self, key, value, friends=None) -> None: self._addline(self.setenv(key, value)) - def info(self, value): + def info(self, value) -> None: for line in value.split('\n'): line = self.escape_string(line) self._addline('echo %s' % line) - def error(self, value): + def error(self, value) -> None: for line in value.split('\n'): line = self.escape_string(line) self._addline('echo %s 1>&2' % line) # escaping is allowed in args, but not in program string - def command(self, value): + def command(self, value) -> None: if is_non_string_iterable(value): it = iter(value) cmd = EscapedString.disallow(next(it)) @@ -535,12 +542,12 @@ def command(self, value): value = EscapedString.disallow(value) self._addline(value) - def comment(self, value): + def comment(self, value) -> None: value = EscapedString.demote(value) for line in value.split('\n'): self._addline('# %s' % line) - def shebang(self): + def shebang(self) -> None: self._addline("#!%s" % self.executable) @classmethod @@ -548,5 +555,5 @@ def get_all_key_tokens(cls, key): return ["${%s}" % key, "$%s" % key] @classmethod - def line_terminator(cls): + def line_terminator(cls) -> str: return "\n" diff --git a/src/rez/solver.py b/src/rez/solver.py index b128dbc88..4c724dcf3 100644 --- a/src/rez/solver.py +++ b/src/rez/solver.py @@ -11,8 +11,10 @@ See SOLVER.md for an in-depth description of how this module works. """ +from __future__ import annotations + from rez.config import config -from rez.packages import iter_packages +from rez.packages import iter_packages, Package, Variant from rez.package_repository import package_repo_stats from rez.utils.logging_ import print_debug from rez.utils.data_utils import cached_property @@ -21,16 +23,31 @@ from rez.vendor.pygraph.algorithms.accessibility import accessibility from rez.exceptions import PackageNotFoundError, ResolveError, \ PackageFamilyNotFoundError, RezSystemError -from rez.version import VersionRange +from rez.version import Version, VersionRange from rez.version import VersionedObject, Requirement, RequirementList +from rez.utils.typing import SupportsLessThan, Protocol from contextlib import contextmanager from enum import Enum from itertools import product, chain +from typing import Any, Callable, Generator, Iterator, TypeVar, TYPE_CHECKING import copy import time import sys import os +if TYPE_CHECKING: + from rez.resolved_context import ResolvedContext + from rez.package_filter import PackageFilterBase + from rez.package_order import PackageOrder + + +T = TypeVar("T") + + +class SupportsWrite(Protocol): + def write(self, __s: str) -> object: + pass + # a hidden control for forcing to non-optimized solving mode. This is here as # first port of call for narrowing down the cause of a solver bug if we see one @@ -74,9 +91,6 @@ class SolverStatus(Enum): cyclic = ("The solve contains a cycle.", ) unsolved = ("The solve has started, but is not yet solved.", ) - def __init__(self, description): - self.description = description - class SolverCallbackReturn(Enum): """Enum returned by the `callback` callable passed to a `Solver` instance. @@ -87,15 +101,15 @@ class SolverCallbackReturn(Enum): class _Printer(object): - def __init__(self, verbosity, buf=None, suppress_passive=False): + def __init__(self, verbosity, buf: SupportsWrite | None = None, suppress_passive: bool = False): self.verbosity = verbosity self.buf = buf or sys.stdout self.suppress_passive = suppress_passive - self.pending_sub = None + self.pending_sub: str | None = None self.pending_br = False self.last_pr = True - def header(self, txt, *args): + def header(self, txt: str, *args: Any) -> None: if self.verbosity: if self.verbosity > 2: self.pr() @@ -104,11 +118,11 @@ def header(self, txt, *args): if self.verbosity > 2: self.pr('-' * 80) - def subheader(self, txt): + def subheader(self, txt: str) -> None: if self.verbosity > 2: self.pending_sub = txt - def __call__(self, txt, *args): + def __call__(self, txt: str, *args: Any) -> None: if self.verbosity > 2: if self.pending_sub: if self.last_pr: @@ -122,19 +136,19 @@ def __call__(self, txt, *args): self.last_pr = True self.pending_br = False - def passive(self, txt, *args): + def passive(self, txt: str, *args: Any) -> None: if self.suppress_passive: return self(txt, *args) - def br(self): + def br(self) -> None: self.pending_br = True - def pr(self, txt='', *args): + def pr(self, txt: str = '', *args: Any) -> None: print(txt % args, file=self.buf) - def __bool__(self): + def __bool__(self) -> bool: return self.verbosity > 0 @@ -142,51 +156,54 @@ class SolverState(object): """Represent the current state of the solver instance for use with a callback. """ - def __init__(self, num_solves, num_fails, phase): + def __init__(self, num_solves: int, num_fails: int, phase: _ResolvePhase): self.num_solves = num_solves self.num_fails = num_fails self.phase = phase - def __str__(self): + def __str__(self) -> str: return ("solve #%d (%d fails so far): %s" % (self.num_solves, self.num_fails, str(self.phase))) class _Common(object): - def __repr__(self): + def __repr__(self) -> str: return "%s(%s)" % (self.__class__.__name__, str(self)) class Reduction(_Common): """A variant was removed because its dependencies conflicted with another scope in the current phase.""" - def __init__(self, name, version, variant_index, dependency, - conflicting_request): + def __init__(self, name: str, version, variant_index: int | None, dependency: Requirement, + conflicting_request: Requirement): self.name = name self.version = version self.variant_index = variant_index self.dependency = dependency self.conflicting_request = conflicting_request - def reducee_str(self): + def reducee_str(self) -> str: stmt = VersionedObject.construct(self.name, self.version) idx_str = "[]" if self.variant_index is None \ else "[%d]" % self.variant_index return str(stmt) + idx_str - def involved_requirements(self): + def involved_requirements(self) -> list[Requirement]: range_ = VersionRange.from_version(self.version) req = Requirement.construct(self.name, range_) return [req, self.dependency, self.conflicting_request] - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + if not isinstance(other, Reduction): + return NotImplemented + return (self.name == other.name and self.version == other.version and self.variant_index == other.variant_index and self.dependency == other.dependency and self.conflicting_request == other.conflicting_request) - def __str__(self): + def __str__(self) -> str: return "%s (dep(%s) <--!--> %s)" \ % (self.reducee_str(), self.dependency, self.conflicting_request) @@ -194,7 +211,7 @@ def __str__(self): class DependencyConflict(_Common): """A common dependency shared by all variants in a scope, conflicted with another scope in the current phase.""" - def __init__(self, dependency, conflicting_request): + def __init__(self, dependency: Requirement, conflicting_request: Requirement): """ Args: dependency (`Requirement`): Merged requirement from a set of variants. @@ -203,73 +220,76 @@ def __init__(self, dependency, conflicting_request): self.dependency = dependency self.conflicting_request = conflicting_request - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + if not isinstance(other, DependencyConflict): + return NotImplemented + return (self.dependency == other.dependency) \ and (self.conflicting_request == other.conflicting_request) - def __str__(self): + def __str__(self) -> str: return "%s <--!--> %s" % (str(self.dependency), str(self.conflicting_request)) class FailureReason(_Common): - def involved_requirements(self): + def involved_requirements(self) -> list[Requirement]: return [] - def description(self): + def description(self) -> str: return "" class TotalReduction(FailureReason): """All of a scope's variants were reduced away.""" - def __init__(self, reductions): + def __init__(self, reductions: list[Reduction]): self.reductions = reductions - def involved_requirements(self): + def involved_requirements(self) -> list[Requirement]: pkgs = [] for red in self.reductions: pkgs.extend(red.involved_requirements()) return pkgs - def description(self): + def description(self) -> str: return "A package was completely reduced: %s" % str(self) def __eq__(self, other): return (self.reductions == other.reductions) - def __str__(self): + def __str__(self) -> str: return ' '.join(("(%s)" % str(x)) for x in self.reductions) class DependencyConflicts(FailureReason): """A common dependency in a scope conflicted with another scope in the current phase.""" - def __init__(self, conflicts): + def __init__(self, conflicts: list[DependencyConflict]): self.conflicts = conflicts - def involved_requirements(self): + def involved_requirements(self) -> list[Requirement]: pkgs = [] for conflict in self.conflicts: pkgs.append(conflict.dependency) pkgs.append(conflict.conflicting_request) return pkgs - def description(self): + def description(self) -> str: return "The following package conflicts occurred: %s" % str(self) - def __eq__(self, other): + def __eq__(self, other) -> bool: return (self.conflicts == other.conflicts) - def __str__(self): + def __str__(self) -> str: return ' '.join(("(%s)" % str(x)) for x in self.conflicts) class Cycle(FailureReason): """The solve contains a cyclic dependency.""" - def __init__(self, packages): + def __init__(self, packages: list[VersionedObject]): self.packages = packages - def involved_requirements(self): + def involved_requirements(self) -> list[Requirement]: pkgs = [] for pkg in self.packages: range_ = VersionRange.from_version(pkg.version) @@ -277,13 +297,13 @@ def involved_requirements(self): pkgs.append(stmt) return pkgs - def description(self): + def description(self) -> str: return "A cyclic dependency was detected: %s" % str(self) - def __eq__(self, other): + def __eq__(self, other) -> bool: return (self.packages == other.packages) - def __str__(self): + def __str__(self) -> str: stmts = self.packages + self.packages[:1] return " --> ".join(map(str, stmts)) @@ -291,7 +311,7 @@ def __str__(self): class PackageVariant(_Common): """A variant of a package. """ - def __init__(self, variant, building): + def __init__(self, variant: Variant, building: bool): """Create a package variant. Args: @@ -302,23 +322,23 @@ def __init__(self, variant, building): self.building = building @property - def name(self): + def name(self) -> str: return self.variant.name @property - def version(self): + def version(self) -> Version: return self.variant.version @property - def index(self): + def index(self) -> int | None: return self.variant.index @property - def handle(self): + def handle(self) -> dict[str, Any]: return self.variant.handle.to_dict() @cached_property - def requires_list(self): + def requires_list(self) -> RequirementList: """ It is important that this property is calculated lazily. Getting the 'requires' attribute may trigger a package load, which may be avoided if @@ -335,31 +355,31 @@ def requires_list(self): return reqlist @property - def request_fams(self): + def request_fams(self) -> set[str]: return self.requires_list.names @property - def conflict_request_fams(self): + def conflict_request_fams(self) -> set[str]: return self.requires_list.conflict_names - def get(self, pkg_name): + def get(self, pkg_name: str) -> Requirement | None: return self.requires_list.get(pkg_name) - def __eq__(self, other): + def __eq__(self, other) -> bool: return ( self.name == other.name and self.version == other.version and self.index == other.index ) - def __lt__(self, other): + def __lt__(self, other) -> bool: return ( self.name < other.name and self.version < other.version and self.index < other.index ) - def __str__(self): + def __str__(self) -> str: stmt = VersionedObject.construct(self.name, self.version) idxstr = '' if self.index is None else str(self.index) return "%s[%s]" % (str(stmt), idxstr) @@ -370,20 +390,20 @@ class _PackageEntry(object): Holds some extra state data, such as whether the variants are sorted. """ - def __init__(self, package, variants, solver): + def __init__(self, package: Package, variants: list[PackageVariant], solver: Solver): self.package = package self.variants = variants self.solver = solver self.sorted = False @property - def version(self): + def version(self) -> Version: return self.package.version - def __len__(self): + def __len__(self) -> int: return len(self.variants) - def split(self, nvariants): + def split(self, nvariants: int) -> tuple[_PackageEntry, _PackageEntry] | None: if nvariants >= len(self.variants): return None @@ -393,7 +413,7 @@ def split(self, nvariants): entry.sorted = next_entry.sorted = True return entry, next_entry - def sort(self): + def sort(self) -> None: """Sort variants from most correct to consume, to least. Sort rules: @@ -420,7 +440,7 @@ def sort(self): if self.sorted: return - def key(variant): + def key(variant: PackageVariant) -> tuple[SupportsLessThan, ...]: requested_key = [] names = set() @@ -441,18 +461,18 @@ def key(variant): additional_key.append((range_key, request.name)) if (VariantSelectMode[config.variant_select_mode] == VariantSelectMode.version_priority): - k = (requested_key, - -len(additional_key), - additional_key, - variant.index) + return (requested_key, + -len(additional_key), + additional_key, + # None does not support proper sorting, so fall back to int + variant.index or -1) else: # VariantSelectMode.intersection_priority - k = (len(requested_key), - requested_key, - -len(additional_key), - additional_key, - variant.index) - - return k + return (len(requested_key), + requested_key, + -len(additional_key), + additional_key, + # None does not support proper sorting, so fall back to int + variant.index or -1) self.variants.sort(key=key, reverse=True) self.sorted = True @@ -461,7 +481,7 @@ def key(variant): class _PackageVariantList(_Common): """A list of package variants, loaded lazily. """ - def __init__(self, package_name, solver): + def __init__(self, package_name: str, solver: Solver): self.package_name = package_name self.solver = solver @@ -469,7 +489,7 @@ def __init__(self, package_name, solver): # cause package loads (eg, timestamp rules). We only apply filters # during an intersection, which minimises the amount of filtering. # - self.entries = [] + self.entries: list[list[Any]] = [] for package in iter_packages(self.package_name, paths=self.solver.package_paths): @@ -481,7 +501,7 @@ def __init__(self, package_name, solver): "package family not found: %s (searched: %s)" % (package_name, "; ".join(self.solver.package_paths))) - def get_intersection(self, range_): + def get_intersection(self, range_: VersionRange) -> list[_PackageEntry] | None: """Get a list of variants that intersect with the given range. Args: @@ -532,7 +552,7 @@ def get_intersection(self, range_): return result or None - def dump(self): + def dump(self) -> None: print(self.package_name) for package, value in self.entries: @@ -546,7 +566,7 @@ def dump(self): else: print(" %s" % str(package)) - def __str__(self): + def __str__(self) -> str: strs = [] for package, value in self.entries: @@ -565,7 +585,7 @@ def __str__(self): class _PackageVariantSlice(_Common): """A subset of a variant list, but with more dependency-related info.""" - def __init__(self, package_name, entries, solver): + def __init__(self, package_name: str, entries: list[_PackageEntry], solver: Solver): """ Args: entries (list of `_PackageEntry`): result of @@ -580,49 +600,51 @@ def __init__(self, package_name, entries, solver): self.sorted = False # calculated on demand - self._len = None - self._range = None - self._fam_requires = None - self._common_fams = None + self._len: int | None = None + self._range: VersionRange | None = None + self._fam_requires: set[str] | None = None + self._common_fams: set[str] | None = None @property - def pr(self): + def pr(self) -> _Printer: return self.solver.pr @property - def range_(self): + def range_(self) -> VersionRange: if self._range is None: versions = (x.version for x in self.entries) self._range = VersionRange.from_versions(versions) return self._range @property - def fam_requires(self): + def fam_requires(self) -> set[str]: self._update_fam_info() + assert self._fam_requires is not None return self._fam_requires @property - def common_fams(self): + def common_fams(self) -> set[str]: self._update_fam_info() + assert self._common_fams is not None return self._common_fams @property - def extractable(self): + def extractable(self) -> bool: """True if there are possible remaining extractions.""" return not self.extracted_fams.issuperset(self.common_fams) @property - def first_variant(self): + def first_variant(self) -> PackageVariant: entry = self.entries[0] entry.sort() return entry.variants[0] - def iter_variants(self): + def iter_variants(self) -> Iterator[PackageVariant]: for entry in self.entries: for variant in entry.variants: yield variant - def intersect(self, range_): + def intersect(self, range_: VersionRange) -> _PackageVariantSlice | None: self.solver.intersection_broad_tests_count += 1 """Remove variants whose version fall outside of the given range.""" @@ -652,7 +674,7 @@ def intersect(self, range_): self.been_intersected_with.add(range_) return self - def reduce_by(self, package_request): + def reduce_by(self, package_request: Requirement) -> tuple[_PackageVariantSlice | None, list[Reduction]]: """Remove variants whos dependencies conflict with the given package request. @@ -675,14 +697,14 @@ def reduce_by(self, package_request): with self.solver.timed(self.solver.reduction_time): return self._reduce_by(package_request) - def _reduce_by(self, package_request): + def _reduce_by(self, package_request: Requirement) -> tuple[_PackageVariantSlice | None, list[Reduction]]: self.solver.reduction_tests_count += 1 entries = [] reductions = [] conflict_tests = {} - def _conflicts(req_): + def _conflicts(req_: Requirement): # cache conflict tests, since variants often share similar requirements req_s = str(req) result = conflict_tests.get(req_s) @@ -727,7 +749,7 @@ def _conflicts(req_): self.been_reduced_by.add(package_request) return (self, []) - def extract(self): + def extract(self) -> tuple[_PackageVariantSlice, Requirement | None]: """Extract a common dependency. Note that conflict dependencies are never extracted, they are always @@ -741,11 +763,12 @@ def extract(self): # the sort is necessary to ensure solves are deterministic fam = sorted(extractable)[0] - last_range = None + last_range: VersionRange | None = None ranges = set() for variant in self.iter_variants(): req = variant.get(fam) + assert req is not None if req.range != last_range: # will match often, avoids set search ranges.add(req.range) last_range = req.range @@ -758,7 +781,7 @@ def extract(self): common_req = Requirement.construct(fam, range_) return slice_, common_req - def split(self): + def split(self) -> tuple[_PackageVariantSlice, _PackageVariantSlice]: """Split the slice. Returns: @@ -772,7 +795,7 @@ def split(self): # self.sort_versions() - def _split(i_entry, n_variants, common_fams=None): + def _split(i_entry: int, n_variants: int, common_fams=None): # perform a split at a specific point result = self.entries[i_entry].split(n_variants) @@ -812,7 +835,7 @@ def _split(i_entry, n_variants, common_fams=None): return _split(0, 1) # find split point - first variant with no dependency shared with previous - prev = None + prev: tuple[int, int, set[str]] | None = None for i, entry in enumerate(self.entries): # sort the variants. This is done here in order to do the sort as # late as possible, simply to avoid the cost. @@ -821,6 +844,7 @@ def _split(i_entry, n_variants, common_fams=None): for j, variant in enumerate(entry.variants): fams = fams & variant.request_fams if not fams: + assert prev is not None return _split(*prev) prev = (i, j + 1, fams) @@ -832,7 +856,7 @@ def _split(i_entry, n_variants, common_fams=None): "Unexpected solver error: common family(s) still in slice being " "split: slice: %s, family(s): %s" % (self, str(fams))) - def sort_versions(self): + def sort_versions(self) -> None: """Sort entries by version. The order is typically descending, but package order functions can @@ -845,7 +869,7 @@ def sort_versions(self): orderer = get_orderer(self.package_name, orderers=self.solver.package_orderers or {}) - def sort_key(entry): + def sort_key(entry: _PackageEntry) -> SupportsLessThan: return orderer.sort_key(entry.package.name, entry.version) self.entries = sorted(self.entries, key=sort_key, reverse=True) @@ -854,11 +878,11 @@ def sort_key(entry): if self.pr: self.pr("sorted: %s packages: %s", self.package_name, repr(orderer)) - def dump(self): + def dump(self) -> None: print(self.package_name) print('\n'.join(map(str, self.iter_variants()))) - def _copy(self, new_entries): + def _copy(self, new_entries: list[_PackageEntry]) -> _PackageVariantSlice: slice_ = _PackageVariantSlice(package_name=self.package_name, entries=new_entries, solver=self.solver) @@ -868,7 +892,7 @@ def _copy(self, new_entries): slice_.been_intersected_with = self.been_intersected_with.copy() return slice_ - def _update_fam_info(self): + def _update_fam_info(self) -> None: if self._common_fams is not None: return @@ -880,7 +904,7 @@ def _update_fam_info(self): self._fam_requires |= (variant.request_fams | variant.conflict_request_fams) - def __len__(self): + def __len__(self) -> int: if self._len is None: self._len = 0 for entry in self.entries: @@ -888,7 +912,7 @@ def __len__(self): return self._len - def __str__(self): + def __str__(self) -> str: """ foo[2..6(3:4)]* means, 3 versions, 4 variants in 2..6, and at least one family can still be extracted. @@ -907,7 +931,8 @@ def __str__(self): s = "[%s==%s%s]" % (self.package_name, str(variant.version), s_idx) elif nversions == 1: entry = self.entries[0] - indexes = sorted([x.index for x in entry.variants]) + # we expect all variants to have a non-None index, but filter to satisfy mypy + indexes = sorted([x.index for x in entry.variants if x.index is not None]) s_idx = ','.join(str(x) for x in indexes) verstr = str(entry.version) s = "[%s==%s[%s]]" % (self.package_name, verstr, s_idx) @@ -923,11 +948,11 @@ def __str__(self): class PackageVariantCache(object): - def __init__(self, solver): + def __init__(self, solver: Solver): self.solver = solver - self.variant_lists = {} # {package-name: _PackageVariantList} + self.variant_lists: dict[str, _PackageVariantList] = {} # {package-name: _PackageVariantList} - def get_variant_slice(self, package_name, range_): + def get_variant_slice(self, package_name: str, range_: VersionRange) -> _PackageVariantSlice | None: """Get a list of variants from the cache. Args: @@ -958,10 +983,9 @@ class _PackageScope(_Common): or a conflict range. As the resolve progresses, package scopes are narrowed down. """ - def __init__(self, package_request, solver): + def __init__(self, package_request: Requirement, solver: Solver): self.package_name = package_request.name self.solver = solver - self.package_request = None self.variant_slice = None self.pr = solver.pr self.is_ephemeral = (package_request.name.startswith('.')) @@ -978,13 +1002,14 @@ def __init__(self, package_request, solver): package_request.range) raise PackageNotFoundError("Package could not be found: %s" % str(req)) + # This call to _update() will set self.package_request self._update() @property - def is_conflict(self): - return self.package_request and self.package_request.conflict + def is_conflict(self) -> bool: + return bool(self.package_request and self.package_request.conflict) - def intersect(self, range_): + def intersect(self, range_: VersionRange) -> _PackageScope | None: """Intersect this scope with a package range. Returns: @@ -1036,6 +1061,8 @@ def intersect(self, range_): new_slice = self.solver._get_variant_slice( self.package_name, new_range) else: + assert self.variant_slice is not None, \ + "variant_slice should always exist for non-conflicted non-ephemeral requests" new_slice = self.variant_slice.intersect(range_) # intersection reduced the scope to nothing @@ -1056,7 +1083,7 @@ def intersect(self, range_): # intersection did not change the scope return self - def reduce_by(self, package_request): + def reduce_by(self, package_request: Requirement) -> tuple[_PackageScope | None, list[Reduction]]: """Reduce this scope wrt a package request. Returns: @@ -1071,6 +1098,9 @@ def reduce_by(self, package_request): if self.is_conflict or self.is_ephemeral: return (self, []) + assert self.variant_slice is not None, \ + "variant_slice should always exist for non-conflicted non-ephemeral requests" + # perform the reduction new_slice, reductions = self.variant_slice.reduce_by(package_request) @@ -1099,7 +1129,7 @@ def reduce_by(self, package_request): # there was no reduction return (self, []) - def extract(self): + def extract(self) -> tuple[_PackageScope, Requirement | None]: """Extract a common dependency. Returns: @@ -1112,6 +1142,9 @@ def extract(self): if self.is_conflict or self.is_ephemeral: return (self, None) + assert self.variant_slice is not None, \ + "variant_slice should always exist for non-conflicted non-ephemeral requests" + new_slice, package_request = self.variant_slice.extract() if not package_request: return (self, None) @@ -1123,7 +1156,7 @@ def extract(self): self.pr("extracted %s from %s", package_request, self) return (scope, package_request) - def split(self): + def split(self) -> tuple[_PackageScope, _PackageScope] | None: """Split the scope. Returns: @@ -1134,10 +1167,15 @@ def split(self): if ( self.is_conflict or self.is_ephemeral - or len(self.variant_slice) == 1 ): return None + assert self.variant_slice is not None, \ + "variant_slice should always exist for non-conflicted non-ephemeral requests" + + if len(self.variant_slice) == 1: + return None + r = self.variant_slice.split() if r is None: return None @@ -1147,23 +1185,24 @@ def split(self): next_scope = self._copy(next_slice) return (scope, next_scope) - def _copy(self, new_slice): + def _copy(self, new_slice: _PackageVariantSlice) -> _PackageScope: scope = copy.copy(self) scope.variant_slice = new_slice scope._update() return scope - def _is_solved(self): + def _is_solved(self) -> bool: return ( self.is_conflict or self.is_ephemeral or ( - len(self.variant_slice) == 1 + self.variant_slice is not None # should never be None here + and len(self.variant_slice) == 1 and not self.variant_slice.extractable ) ) - def _get_solved_variant(self): + def _get_solved_variant(self) -> PackageVariant | None: if ( self.variant_slice is not None and len(self.variant_slice) == 1 @@ -1173,25 +1212,25 @@ def _get_solved_variant(self): else: return None - def _get_solved_ephemeral(self): + def _get_solved_ephemeral(self) -> Requirement | None: if self.is_ephemeral and not self.is_conflict: return self.package_request else: return None - def _update(self): + def _update(self) -> None: if self.variant_slice is not None: self.package_request = Requirement.construct( self.package_name, self.variant_slice.range_) - def __str__(self): + def __str__(self) -> str: if self.variant_slice is None: return str(self.package_request) else: return str(self.variant_slice) -def _get_dependency_order(g, node_list): +def _get_dependency_order(g: digraph, node_list: list[T]) -> list[T]: """Return list of nodes as close as possible to the ordering in node_list, but with child nodes earlier in the list than parents.""" access_ = accessibility(g) @@ -1230,10 +1269,10 @@ class _ResolvePhase(_Common): If the resolve phase gets to a point where every package scope is solved, then the entire resolve is considered to be solved. """ - def __init__(self, solver): + def __init__(self, solver: Solver): self.solver = solver - self.failure_reason = None - self.extractions = {} + self.failure_reason: FailureReason | None = None + self.extractions: dict[tuple[str, str], Requirement] = {} self.status = SolverStatus.pending self.scopes = [] @@ -1245,21 +1284,21 @@ def __init__(self, solver): self.changed_scopes_i = set(range(len(self.scopes))) @property - def pr(self): + def pr(self) -> _Printer: return self.solver.pr - def solve(self): + def solve(self) -> _ResolvePhase: """Attempt to solve the phase.""" if self.status != SolverStatus.pending: return self scopes = self.scopes[:] - failure_reason = None - extractions = {} + failure_reason: FailureReason | None = None + extractions: dict[tuple[str, str], Requirement] = {} changed_scopes_i = self.changed_scopes_i.copy() - def _create_phase(status=None): + def _create_phase(status: SolverStatus | None = None) -> _ResolvePhase: phase = copy.copy(self) phase.scopes = scopes phase.failure_reason = failure_reason @@ -1281,7 +1320,7 @@ def _create_phase(status=None): # iteratively extract until no more extractions possible while True: self.pr.subheader("EXTRACTING:") - extracted_requests = [] + extracted_requests_ = [] # perform all possible extractions with self.solver.timed(self.solver.extraction_time): @@ -1290,7 +1329,7 @@ def _create_phase(status=None): scope_, extracted_request = scopes[i].extract() if extracted_request: - extracted_requests.append(extracted_request) + extracted_requests_.append(extracted_request) k = (scopes[i].package_name, extracted_request.name) extractions[k] = extracted_request self.solver.extractions_count += 1 @@ -1298,12 +1337,12 @@ def _create_phase(status=None): else: break - if not extracted_requests: + if not extracted_requests_: break # simplify extractions (there may be overlaps) self.pr.subheader("MERGE-EXTRACTIONS:") - extracted_requests = RequirementList(extracted_requests) + extracted_requests = RequirementList(extracted_requests_) if extracted_requests.conflict: # extractions are in conflict req1, req2 = extracted_requests.conflict @@ -1325,21 +1364,21 @@ def _create_phase(status=None): continue # perform the intersection - scope_ = scope.intersect(extracted_req.range) + new_scope = scope.intersect(extracted_req.range) req_fams.append(extracted_req.name) - if scope_ is None: + if new_scope is None: # the scope conflicted with the extraction conflict = DependencyConflict( extracted_req, scope.package_request) failure_reason = DependencyConflicts([conflict]) return _create_phase(SolverStatus.failed) - if scope_ is not scope: + if new_scope is not scope: # the scope was narrowed because it intersected # with an extraction - scopes[i] = scope_ + scopes[i] = new_scope changed_scopes_i.add(i) self.solver.intersections_count += 1 @@ -1446,10 +1485,10 @@ def _create_phase(status=None): # A different order here wouldn't cause an invalid solve, however # rez solves must be deterministic, so this is why we sort. # - pending_reducts = sorted(pending_reducts) + pending_reducts_ = sorted(pending_reducts) - while pending_reducts: - x, y = pending_reducts.pop() + while pending_reducts_: + x, y = pending_reducts_.pop() if x == y: continue @@ -1466,13 +1505,13 @@ def _create_phase(status=None): # other scopes need to reduce against x again for j in all_scopes_i: if j != x: - pending_reducts.append((j, x)) + pending_reducts_.append((j, x)) changed_scopes_i = set() return _create_phase() - def finalise(self): + def finalise(self) -> _ResolvePhase: """Remove conflict requests, detect cyclic dependencies, and reorder packages wrt dependency and then request order. @@ -1483,6 +1522,8 @@ def finalise(self): """ assert self._is_solved() g = self._get_minimal_graph() + assert g is not None, "graph should always be present when solved" + scopes = dict((x.package_name, x) for x in self.scopes if not x.is_conflict) @@ -1493,11 +1534,12 @@ def finalise(self): for fam in fam_cycle: scope = scopes[fam] variant = scope._get_solved_variant() + assert variant is not None, "variant should not be None when scope is solved" stmt = VersionedObject.construct(fam, variant.version) cycle.append(stmt) phase = copy.copy(self) - phase.scopes = scopes.values() + phase.scopes = list(scopes.values()) phase.failure_reason = Cycle(cycle) phase.status = SolverStatus.cyclic return phase @@ -1516,7 +1558,7 @@ def finalise(self): phase.scopes = scopes_ return phase - def split(self): + def split(self) -> tuple[_ResolvePhase, _ResolvePhase]: """Split the phase. When a phase is exhausted, it gets split into a pair of phases to be @@ -1540,7 +1582,7 @@ def split(self): scopes = [] next_scopes = [] - split_i = None + split_i: int | None = None for i, scope in enumerate(self.scopes): if split_i is None: @@ -1572,7 +1614,7 @@ def split(self): next_phase.scopes = next_scopes return (phase, next_phase) - def get_graph(self): + def get_graph(self) -> digraph: """Get the resolve graph. The resolve graph shows what packages were resolved, and the @@ -1597,12 +1639,12 @@ def get_graph(self): node_fontsize = 10 counter = [1] - def _uid(): + def _uid() -> str: id_ = counter[0] counter[0] += 1 return "_%d" % id_ - def _add_edge(id1, id2, arrowsize=0.5): + def _add_edge(id1: str, id2: str, arrowsize=0.5) -> tuple[str, str]: e = (id1, id2) if g.has_edge(e): g.del_edge(e) @@ -1610,30 +1652,30 @@ def _add_edge(id1, id2, arrowsize=0.5): g.add_edge_attribute(e, ("arrowsize", str(arrowsize))) return e - def _add_extraction_merge_edge(id1, id2): + def _add_extraction_merge_edge(id1: str, id2: str): e = _add_edge(id1, id2, 1) g.add_edge_attribute(e, ("arrowhead", "odot")) - def _add_conflict_edge(id1, id2): + def _add_conflict_edge(id1: str, id2: str): e = _add_edge(id1, id2, 1) g.set_edge_label(e, "CONFLICT") g.add_edge_attribute(e, ("style", "bold")) g.add_edge_attribute(e, ("color", "red")) g.add_edge_attribute(e, ("fontcolor", "red")) - def _add_cycle_edge(id1, id2): + def _add_cycle_edge(id1: str, id2: str): e = _add_edge(id1, id2, 1) g.set_edge_label(e, "CYCLE") g.add_edge_attribute(e, ("style", "bold")) g.add_edge_attribute(e, ("color", "red")) g.add_edge_attribute(e, ("fontcolor", "red")) - def _add_reduct_edge(id1, id2, label): + def _add_reduct_edge(id1: str, id2: str, label: str): e = _add_edge(id1, id2, 1) g.set_edge_label(e, label) g.add_edge_attribute(e, ("fontsize", node_fontsize)) - def _add_node(label, color, style): + def _add_node(label: str, color: str, style: str) -> str: attrs = [("label", label), ("fontsize", node_fontsize), ("fillcolor", color), @@ -1642,7 +1684,7 @@ def _add_node(label, color, style): g.add_node(id_, attrs=attrs) return id_ - def _add_request_node(request, initial_request=False): + def _add_request_node(request: Requirement, initial_request: bool = False) -> str: id_ = request_nodes.get(request) if id_ is not None: return id_ @@ -1657,7 +1699,7 @@ def _add_request_node(request, initial_request=False): request_nodes[request] = id_ return id_ - def _add_scope_node(scope): + def _add_scope_node(scope: _PackageScope) -> str: id_ = scope_nodes.get(scope.package_name) if id_ is not None: return id_ @@ -1681,7 +1723,7 @@ def _add_scope_node(scope): scope_requests[id_] = scope.package_request return id_ - def _add_reduct_node(request): + def _add_reduct_node(request: Requirement) -> str: return _add_node(str(request), node_color, "filled,dashed") # -- generate the graph @@ -1736,6 +1778,7 @@ def _add_reduct_node(request): reqlist = RequirementList(requests) if not reqlist.conflict: merged_request = reqlist.get(fam) + assert merged_request is not None for request in requests: if merged_request != request: id1 = _add_request_node(request) @@ -1749,7 +1792,7 @@ def _add_reduct_node(request): for conflict in fr.conflicts: conflicting_request = conflict.conflicting_request scope_n = scope_nodes.get(conflicting_request.name) - scope_r = scope_requests.get(scope_n) + scope_r = scope_requests.get(scope_n) if scope_n is not None else None if scope_n is not None \ and scope_r is not None \ @@ -1807,7 +1850,7 @@ def _add_reduct_node(request): if not g.neighbors(id1): # leaf node id2 = scope_nodes.get(request.name) if id2 is not None: - scope = scopes.get(request.name) + scope = scopes[request.name] if not request.conflicts_with(scope.package_request): _add_edge(id1, id2) @@ -1825,7 +1868,7 @@ def _add_reduct_node(request): return g - def _get_minimal_graph(self): + def _get_minimal_graph(self) -> digraph | None: if not self._is_solved(): return None @@ -1852,13 +1895,13 @@ def _get_minimal_graph(self): return g - def _is_solved(self): + def _is_solved(self) -> bool: for scope in self.scopes: if not scope._is_solved(): return False return True - def _get_solved_variants(self): + def _get_solved_variants(self) -> list[PackageVariant]: variants = [] for scope in self.scopes: variant = scope._get_solved_variant() @@ -1867,7 +1910,7 @@ def _get_solved_variants(self): return variants - def _get_solved_ephemerals(self): + def _get_solved_ephemerals(self) -> list[Requirement]: ephemerals = [] for scope in self.scopes: ephemeral = scope._get_solved_ephemeral() @@ -1876,7 +1919,7 @@ def _get_solved_ephemerals(self): return ephemerals - def __str__(self): + def __str__(self) -> str: return ' '.join(str(x) for x in self.scopes) @@ -1889,11 +1932,21 @@ class Solver(_Common): """ max_verbosity = 3 - def __init__(self, package_requests, package_paths, context=None, - package_filter=None, package_orderers=None, callback=None, - building=False, optimised=True, verbosity=0, buf=None, - package_load_callback=None, prune_unfailed=True, - suppress_passive=False, print_stats=False): + def __init__(self, + package_requests: list[Requirement], + package_paths: list[str], + context: ResolvedContext | None = None, + package_filter: PackageFilterBase | None = None, + package_orderers: list[PackageOrder] | None = None, + callback: Callable[[SolverState], tuple[SolverCallbackReturn, str]] | None = None, + building: bool = False, + optimised: bool = True, + verbosity: int = 0, + buf: SupportsWrite | None = None, + package_load_callback: Callable[[Package], Any] | None = None, + prune_unfailed: bool = True, + suppress_passive: bool = False, + print_stats: bool = False): """Create a Solver. Args: @@ -1934,7 +1987,6 @@ def __init__(self, package_requests, package_paths, context=None, self.prune_unfailed = prune_unfailed self.package_load_callback = package_load_callback self.building = building - self.request_list = None self.context = context self.pr = _Printer(verbosity, buf=buf, suppress_passive=suppress_passive) @@ -1946,14 +1998,16 @@ def __init__(self, package_requests, package_paths, context=None, else: self.optimised = optimised - self.phase_stack = None - self.failed_phase_list = None - self.abort_reason = None - self.callback_return = None - self.depth_counts = None - self.solve_begun = None - self.solve_time = None - self.load_time = None + # these values are all set in _init() + self.phase_stack: list[_ResolvePhase] + self.failed_phase_list: list[_ResolvePhase] + self.depth_counts: dict + self.solve_begun: bool + self.solve_time: float + self.load_time: float + + self.abort_reason: str | None = None + self.callback_return: SolverCallbackReturn | None = None # advanced solve metrics self.solve_count = 0 @@ -2000,14 +2054,14 @@ def __init__(self, package_requests, package_paths, context=None, self._push_phase(phase) @contextmanager - def timed(self, target): + def timed(self, target: list[float]) -> Generator: t = time.time() yield secs = time.time() - t target[0] += secs @property - def status(self): + def status(self) -> SolverStatus: """Return the current status of the solve. Returns: @@ -2035,12 +2089,12 @@ def status(self): return st @property - def num_solves(self): + def num_solves(self) -> int: """Return the number of solve steps that have been executed.""" return self.solve_count @property - def num_fails(self): + def num_fails(self) -> int: """Return the number of failed solve steps that have been executed. Note that num_solves is inclusive of failures.""" n = len(self.failed_phase_list) @@ -2049,12 +2103,12 @@ def num_fails(self): return n @property - def cyclic_fail(self): + def cyclic_fail(self) -> bool: """Return True if the solve failed due to a cycle, False otherwise.""" return (self.phase_stack[-1].status == SolverStatus.cyclic) @property - def resolved_packages(self): + def resolved_packages(self) -> list[PackageVariant] | None: """Return a list of resolved variants. Returns: @@ -2068,7 +2122,7 @@ def resolved_packages(self): return final_phase._get_solved_variants() @property - def resolved_ephemerals(self): + def resolved_ephemerals(self) -> list[Requirement] | None: """Return the list of final ephemeral package ranges. Note that conflict ephemerals are not included. @@ -2083,15 +2137,15 @@ def resolved_ephemerals(self): final_phase = self.phase_stack[-1] return final_phase._get_solved_ephemerals() - def reset(self): + def reset(self) -> None: """Reset the solver, removing any current solve.""" if not self.request_list.conflict: - phase = _ResolvePhase(self.request_list.requirements, solver=self) + phase = _ResolvePhase(solver=self) self.pr("resetting...") self._init() self._push_phase(phase) - def solve(self): + def solve(self) -> None: """Attempt to solve the request. """ if self.solve_begun: @@ -2122,7 +2176,7 @@ def solve(self): print(pformat(data), file=(self.buf or sys.stdout)) @property - def solve_stats(self): + def solve_stats(self) -> dict[str, dict[str, Any]]: extraction_stats = { "extraction_time": self.extraction_time[0], "num_extractions": self.extractions_count @@ -2158,7 +2212,7 @@ def solve_stats(self): "reductions": reduction_stats } - def solve_step(self): + def solve_step(self) -> None: """Perform a single solve step. """ self.solve_begun = True @@ -2209,7 +2263,7 @@ def solve_step(self): assert new_phase.status == SolverStatus.exhausted self._push_phase(new_phase) - def failure_reason(self, failure_index=None): + def failure_reason(self, failure_index: int | None = None) -> FailureReason | None: """Get the reason for a failure. Args: @@ -2228,7 +2282,7 @@ def failure_reason(self, failure_index=None): phase, _ = self._get_failed_phase(failure_index) return phase.failure_reason - def failure_description(self, failure_index=None): + def failure_description(self, failure_index: int | None = None) -> str: """Get a description of the failure. This differs from `failure_reason` - in some cases, such as when a @@ -2238,7 +2292,7 @@ def failure_description(self, failure_index=None): _, description = self._get_failed_phase(failure_index) return description - def failure_packages(self, failure_index=None): + def failure_packages(self, failure_index: int | None = None) -> list[Requirement] | None: """Get packages involved in a failure. Args: @@ -2251,7 +2305,7 @@ def failure_packages(self, failure_index=None): fr = phase.failure_reason return fr.involved_requirements() if fr else None - def get_graph(self): + def get_graph(self) -> digraph: """Returns the most recent solve graph. This gives a graph showing the latest state of the solve. The specific @@ -2267,11 +2321,12 @@ def get_graph(self): st = self.status if st in (SolverStatus.solved, SolverStatus.unsolved): phase = self._latest_nonfailed_phase() + assert phase is not None, "Should only be None if status is failed" return phase.get_graph() else: return self.get_fail_graph() - def get_fail_graph(self, failure_index=None): + def get_fail_graph(self, failure_index: int | None = None) -> digraph: """Returns a graph showing a solve failure. Args: @@ -2283,7 +2338,7 @@ def get_fail_graph(self, failure_index=None): phase, _ = self._get_failed_phase(failure_index) return phase.get_graph() - def dump(self): + def dump(self) -> None: """Print a formatted summary of the current solve state.""" from rez.utils.formatting import columnise @@ -2291,7 +2346,7 @@ def dump(self): for i, phase in enumerate(self.phase_stack): rows.append((self._depth_label(i), phase.status, str(phase))) - print("status: %s (%s)" % (self.status.name, self.status.description)) + print("status: %s (%s)" % (self.status.name, self.status.value[0])) print("initial request: %s" % str(self.request_list)) print() print("solve stack:") @@ -2305,7 +2360,7 @@ def dump(self): print("previous failures:") print('\n'.join(columnise(rows))) - def _init(self): + def _init(self) -> None: self.phase_stack = [] self.failed_phase_list = [] self.depth_counts = {} @@ -2329,7 +2384,7 @@ def _init(self): self.reduction_time = [0.0] self.reduction_test_time = [0.0] - def _latest_nonfailed_phase(self): + def _latest_nonfailed_phase(self) -> _ResolvePhase | None: if self.status == SolverStatus.failed: return None @@ -2338,7 +2393,7 @@ def _latest_nonfailed_phase(self): return phase assert False # should never get here - def _do_callback(self): + def _do_callback(self) -> bool: keep_going = True if self.callback: phase = self._latest_nonfailed_phase() @@ -2358,13 +2413,13 @@ def _do_callback(self): return keep_going - def _get_variant_slice(self, package_name, range_): + def _get_variant_slice(self, package_name: str, range_: VersionRange) -> _PackageVariantSlice | None: slice_ = self.package_cache.get_variant_slice( package_name=package_name, range_=range_) return slice_ - def _push_phase(self, phase): + def _push_phase(self, phase: _ResolvePhase) -> None: depth = len(self.phase_stack) count = self.depth_counts.get(depth, -1) + 1 self.depth_counts[depth] = count @@ -2374,14 +2429,14 @@ def _push_phase(self, phase): dlabel = self._depth_label() self.pr("pushed %s: %s", dlabel, phase) - def _pop_phase(self): + def _pop_phase(self) -> _ResolvePhase: dlabel = self._depth_label() phase = self.phase_stack.pop() if self.pr: self.pr("popped %s: %s", dlabel, phase) return phase - def _get_failed_phase(self, index=None): + def _get_failed_phase(self, index: int | None = None) -> tuple[_ResolvePhase, str]: # returns (phase, fail_description) prepend_abort_reason = False fails = self.failed_phase_list @@ -2412,19 +2467,19 @@ def _get_failed_phase(self, index=None): return phase, fail_description - def _depth_label(self, depth=None): + def _depth_label(self, depth: int | None = None) -> str: if depth is None: depth = len(self.phase_stack) - 1 count = self.depth_counts[depth] return "{%d,%d}" % (depth, count) - def __str__(self): + def __str__(self) -> str: return "%s %s %s" % (self.status, self._depth_label(), str(self.phase_stack[-1])) -def _short_req_str(package_request): +def _short_req_str(package_request: Requirement) -> str: """print shortened version of '==X|==Y|==Z' ranged requests.""" if not package_request.conflict: versions = package_request.range.to_versions() diff --git a/src/rez/status.py b/src/rez/status.py index 5be0b9d76..40e91a2e6 100644 --- a/src/rez/status.py +++ b/src/rez/status.py @@ -157,19 +157,19 @@ def print_tools(self, pattern=None, buf=sys.stdout): if pattern and not fnmatch(tool, pattern): continue - label = [] + label_parts = [] color = None path = which(tool) if path: path_ = os.path.join(suite.tools_path, tool) if path != path_: - label.append("(hidden by unknown tool '%s')" % path) + label_parts.append("(hidden by unknown tool '%s')" % path) color = warning variant = d["variant"] if isinstance(variant, set): pkg_str = ", ".join(variant) - label.append("(in conflict)") + label_parts.append("(in conflict)") color = critical else: pkg_str = variant.qualified_package_name @@ -178,7 +178,7 @@ def print_tools(self, pattern=None, buf=sys.stdout): if orig_tool == tool: orig_tool = '-' - label = ' '.join(label) + label = ' '.join(label_parts) source = ("context '%s' in suite '%s'" % (d["context_name"], suite.load_path)) diff --git a/src/rez/suite.py b/src/rez/suite.py index 7b84571e8..5b0f44148 100644 --- a/src/rez/suite.py +++ b/src/rez/suite.py @@ -2,6 +2,8 @@ # Copyright Contributors to the Rez Project +from __future__ import annotations + from rez.utils.execution import create_forwarding_script from rez.exceptions import SuiteError, ResolvedContextError from rez.resolved_context import ResolvedContext @@ -12,12 +14,40 @@ from rez.vendor.yaml.error import YAMLError from rez.utils.yaml import dump_yaml from collections import defaultdict +from typing import TYPE_CHECKING, Any import os import os.path import shutil import sys +if TYPE_CHECKING: + from rez.packages import Variant + from typing import TypedDict + + # FIXME: move this out of TYPE_CHECKING block when python 3.7 support is dropped + class Tool(TypedDict): + tool_name: str + tool_alias: str + context_name: str + variant: Variant | set[Variant] + + class Context(TypedDict): + name: str + context: ResolvedContext + tool_aliases: dict[str, str] + hidden_tools: set[str] + priority: int + prefix_char: str | None + loaded: bool + prefix: str + suffix: str + +else: + Tool = dict + Context = dict + + class Suite(object): """A collection of contexts. @@ -43,15 +73,15 @@ class Suite(object): def __init__(self): """Create a suite.""" self.load_path = None - self.contexts = {} + self.contexts: dict[str, Context] = {} self.next_priority = 1 - self.tools = None - self.tool_conflicts = None - self.hidden_tools = None + self.tools: dict[str, Tool] | None = None + self.tool_conflicts: defaultdict[str, list[Tool]] | None = None + self.hidden_tools: list[Tool] | None = None @property - def context_names(self): + def context_names(self) -> list[str]: """Get the names of the contexts in the suite. Reurns: @@ -105,7 +135,7 @@ def context(self, name): data["loaded"] = True return context - def add_context(self, name, context, prefix_char=None): + def add_context(self, name: str, context: ResolvedContext, prefix_char=None): """Add a context to the suite. Args: @@ -117,12 +147,12 @@ def add_context(self, name, context, prefix_char=None): if not context.success: raise SuiteError("Context is not resolved: %r" % name) - self.contexts[name] = dict(name=name, - context=context.copy(), - tool_aliases={}, - hidden_tools=set(), - priority=self._next_priority, - prefix_char=prefix_char) + self.contexts[name] = Context(name=name, + context=context.copy(), + tool_aliases={}, + hidden_tools=set(), + priority=self._next_priority, + prefix_char=prefix_char) self._flush_tools() def find_contexts(self, in_request=None, in_resolve=None): @@ -167,7 +197,7 @@ def _in_resolve(name): names = [x for x in names if _in_resolve(x)] return names - def remove_context(self, name): + def remove_context(self, name: str): """Remove a context from the suite. Args: @@ -305,6 +335,7 @@ def get_tools(self): a tool of the same name), this will be a set of Variants. """ self._update_tools() + assert self.tools is not None return self.tools def get_tool_filepath(self, tool_alias): @@ -327,7 +358,7 @@ def get_tool_filepath(self, tool_alias): else: return None - def get_tool_context(self, tool_alias): + def get_tool_context(self, tool_alias: str) -> str | None: """Given a visible tool alias, return the name of the context it belongs to. @@ -344,7 +375,7 @@ def get_tool_context(self, tool_alias): return data["context_name"] return None - def get_hidden_tools(self): + def get_hidden_tools(self) -> list[Tool]: """Get the tools hidden in this suite. Hidden tools are those that have been explicitly hidden via `hide_tool`. @@ -358,18 +389,20 @@ def get_hidden_tools(self): - variant (`Variant`): Variant providing the tool. """ self._update_tools() + assert self.hidden_tools is not None return self.hidden_tools - def get_conflicting_aliases(self): + def get_conflicting_aliases(self) -> list[str]: """Get a list of tool aliases that have one or more conflicts. Returns: List of strings. """ self._update_tools() + assert self.tool_conflicts is not None return list(self.tool_conflicts.keys()) - def get_alias_conflicts(self, tool_alias): + def get_alias_conflicts(self, tool_alias: str) -> list[Tool] | None: """Get a list of conflicts on the given tool alias. Args: @@ -383,9 +416,10 @@ def get_alias_conflicts(self, tool_alias): - variant (`Variant`): Variant providing the tool. """ self._update_tools() + assert self.tool_conflicts is not None return self.tool_conflicts.get(tool_alias) - def validate(self): + def validate(self) -> None: """Validate the suite.""" for context_name in self.context_names: context = self.context(context_name) @@ -398,7 +432,7 @@ def validate(self): def to_dict(self): contexts_ = {} for k, data in self.contexts.items(): - data_ = data.copy() + data_: dict[str, Any] = data.copy() if "context" in data_: del data_["context"] if "loaded" in data_: @@ -614,8 +648,8 @@ def _get_row(entry): else: context_names = sorted(self.contexts.keys()) - rows = [["TOOL", "ALIASING", "PACKAGE", "CONTEXT", ""], - ["----", "--------", "-------", "-------", ""]] + rows = [("TOOL", "ALIASING", "PACKAGE", "CONTEXT", ""), + ("----", "--------", "-------", "-------", "")] colors = [None, None] entries_dict = defaultdict(list) @@ -666,7 +700,7 @@ def _get_row(entry): else: _pr("No tools available.") - def _context(self, name): + def _context(self, name: str) -> Context: data = self.contexts.get(name) if not data: raise SuiteError("No such context: %r" % name) @@ -679,11 +713,11 @@ def _context_path(self, name, suite_path=None): filepath = os.path.join(suite_path, "contexts", "%s.rxt" % name) return filepath - def _sorted_contexts(self): + def _sorted_contexts(self) -> list[Context]: return sorted(self.contexts.values(), key=lambda x: x["priority"]) @property - def _next_priority(self): + def _next_priority(self) -> int: p = self.next_priority self.next_priority += 1 return p @@ -725,7 +759,7 @@ def _update_tools(self): if alias is None: alias = "%s%s%s" % (prefix, tool_name, suffix) - entry = dict(tool_name=tool_name, + entry = Tool(tool_name=tool_name, tool_alias=alias, context_name=context_name, variant=variant) diff --git a/src/rez/system.py b/src/rez/system.py index 8486a71af..7793b68f4 100644 --- a/src/rez/system.py +++ b/src/rez/system.py @@ -61,7 +61,7 @@ def variant(self): # TODO: move shell detection into shell plugins @cached_property - def shell(self): + def shell(self) -> str: """Get the current shell. Returns: @@ -88,7 +88,7 @@ def shell(self): args = ['ps', '-o', 'args=', '-p', str(parent_pid)] proc = sp.Popen(args, stdout=sp.PIPE, text=True) output = proc.communicate()[0] - shell = os.path.basename(output.strip().split()[0]).replace('-', '') + shell = os.path.basename(output.decode().strip().split()[0]).replace('-', '') except Exception: pass diff --git a/src/rez/utils/__init__.py b/src/rez/utils/__init__.py index 58a0da7ec..fbf44b704 100644 --- a/src/rez/utils/__init__.py +++ b/src/rez/utils/__init__.py @@ -4,6 +4,7 @@ import sys from contextlib import contextmanager +from typing import NoReturn @contextmanager @@ -11,11 +12,11 @@ def with_noop(): yield -def reraise(exc, new_exc_cls): +def reraise(exc, new_exc_cls) -> NoReturn: traceback = sys.exc_info()[2] # TODO test this. - def reraise_(tp, value, tb=None): + def reraise_(tp, value, tb=None) -> NoReturn: try: if value is None: value = tp() diff --git a/src/rez/utils/backcompat.py b/src/rez/utils/backcompat.py index 3afd42e3b..606b5b70d 100644 --- a/src/rez/utils/backcompat.py +++ b/src/rez/utils/backcompat.py @@ -52,7 +52,7 @@ def convert_old_command_expansions(command): within_unescaped_quotes_regex = re.compile('(? str: """Converts old-style package commands into equivalent Rex code.""" from rez.config import config from rez.utils.logging_ import print_debug diff --git a/src/rez/utils/data_utils.py b/src/rez/utils/data_utils.py index e5f4c009e..a856ffa2b 100644 --- a/src/rez/utils/data_utils.py +++ b/src/rez/utils/data_utils.py @@ -5,6 +5,8 @@ """ Utilities related to managing data types. """ +from __future__ import annotations + import os.path import json import functools @@ -12,6 +14,9 @@ from rez.vendor.schema.schema import Schema, Optional from threading import Lock +from typing import Generic, TypeVar, TYPE_CHECKING + +T = TypeVar("T") class ModifyList(object): @@ -213,50 +218,53 @@ def get_dict_diff_str(d1, d2, title): return '\n'.join(lines) -class cached_property(object): - """Simple property caching descriptor. - - Example: - - >>> class Foo(object): - >>> @cached_property - >>> def bah(self): - >>> print('bah') - >>> return 1 - >>> - >>> f = Foo() - >>> f.bah - bah - 1 - >>> f.bah - 1 - """ - def __init__(self, func, name=None): - self.func = func - # Make sure that Sphinx autodoc can follow and get the docstring from our wrapped function. - functools.update_wrapper(self, func) - self.name = name or func.__name__ - - def __get__(self, instance, owner=None): - if instance is None: - return self - - result = self.func(instance) - try: - setattr(instance, self.name, result) - except AttributeError: - raise AttributeError("can't set attribute %r on %r" - % (self.name, instance)) - return result +if TYPE_CHECKING: + cached_property = property +else: + class cached_property(object): + """Simple property caching descriptor. + + Example: + + >>> class Foo(object): + >>> @cached_property + >>> def bah(self): + >>> print('bah') + >>> return 1 + >>> + >>> f = Foo() + >>> f.bah + bah + 1 + >>> f.bah + 1 + """ + def __init__(self, func, name=None): + self.func = func + # Make sure that Sphinx autodoc can follow and get the docstring from our wrapped function. + functools.update_wrapper(self, func) + self.name = name or func.__name__ + + def __get__(self, instance, owner=None): + if instance is None: + return self + + result = self.func(instance) + try: + setattr(instance, self.name, result) + except AttributeError: + raise AttributeError("can't set attribute %r on %r" + % (self.name, instance)) + return result - # This is to silence Sphinx that complains that cached_property is not a callable. - def __call__(self): - raise RuntimeError("@cached_property should not be called.") + # This is to silence Sphinx that complains that cached_property is not a callable. + def __call__(self): + raise RuntimeError("@cached_property should not be called.") - @classmethod - def uncache(cls, instance, name): - if hasattr(instance, name): - delattr(instance, name) + @classmethod + def uncache(cls, instance, name): + if hasattr(instance, name): + delattr(instance, name) class cached_class_property(object): @@ -293,16 +301,16 @@ def __get__(self, instance, owner=None): return result -class LazySingleton(object): +class LazySingleton(Generic[T]): """A threadsafe singleton that initialises when first referenced.""" - def __init__(self, instance_class, *nargs, **kwargs): + def __init__(self, instance_class: type[T], *nargs, **kwargs): self.instance_class = instance_class self.nargs = nargs self.kwargs = kwargs self.lock = Lock() - self.instance = None + self.instance: T | None = None - def __call__(self): + def __call__(self) -> T: if self.instance is None: try: self.lock.acquire() diff --git a/src/rez/utils/filesystem.py b/src/rez/utils/filesystem.py index fcf0db4f9..06c08beb3 100644 --- a/src/rez/utils/filesystem.py +++ b/src/rez/utils/filesystem.py @@ -115,6 +115,7 @@ def make_path_writable(path): yield finally: if new_mode != orig_mode: + assert orig_mode is not None os.chmod(path, orig_mode) diff --git a/src/rez/utils/patching.py b/src/rez/utils/patching.py index 1252867bb..a422e9831 100644 --- a/src/rez/utils/patching.py +++ b/src/rez/utils/patching.py @@ -42,8 +42,8 @@ def get_patched_request(requires, patchlist): '^': (True, True, True) } - requires = [Requirement(x) if not isinstance(x, Requirement) else x - for x in requires] + requires: list[Requirement | None] = [ + Requirement(x) if not isinstance(x, Requirement) else x for x in requires] appended = [] for patch in patchlist: diff --git a/src/rez/utils/platform_.py b/src/rez/utils/platform_.py index f0901a1ca..d397cfad0 100644 --- a/src/rez/utils/platform_.py +++ b/src/rez/utils/platform_.py @@ -2,6 +2,8 @@ # Copyright Contributors to the Rez Project +from __future__ import annotations + import platform import os import os.path @@ -18,7 +20,7 @@ class Platform(object): """Abstraction of a platform. """ - name = None + name: str def __init__(self): pass @@ -555,7 +557,8 @@ def _difftool(self): # singleton -platform_ = None +# FIXME: is is valid for platform_ to be None? +platform_: Platform = None name = platform.system().lower() if name == "linux": platform_ = LinuxPlatform() diff --git a/src/rez/utils/resources.py b/src/rez/utils/resources.py index 372d1455f..3c3c88922 100644 --- a/src/rez/utils/resources.py +++ b/src/rez/utils/resources.py @@ -33,6 +33,8 @@ See the 'pets' unit test in tests/test_resources.py for a complete example. """ +from __future__ import annotations + from functools import lru_cache from rez.utils.data_utils import cached_property, AttributeForwardMeta, \ @@ -41,6 +43,15 @@ from rez.exceptions import ResourceError from rez.utils.logging_ import print_debug +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + # this is not available in typing until 3.11, but due to __future__.annotations + # we can use it without really importing it + from typing import Self + from rez.vendor.schema.schema import Schema + from rez.package_repository import PackageRepository + class Resource(object, metaclass=LazyAttributeMeta): """Abstract base class for a data resource. @@ -69,14 +80,19 @@ class Resource(object, metaclass=LazyAttributeMeta): `validated_data` function, and test full validation using `validate_data`. """ #: Unique identifier of the resource type. - key = None + key: str = None #: Schema for the resource data. #: Must validate a dict. Can be None, in which case the resource does #: not load any data. - schema = None + schema: Schema | None = None #: The exception type to raise on key validation failure. schema_error = Exception + if TYPE_CHECKING: + # all Resources that are acquired using PackageRepository.get_resource + # have this attribute added to them + _repository: PackageRepository + @classmethod def normalize_variables(cls, variables): """Give subclasses a chance to standardize values for certain variables @@ -87,7 +103,7 @@ def __init__(self, variables=None): self.variables = self.normalize_variables(variables or {}) @cached_property - def handle(self): + def handle(self) -> ResourceHandle: """Get the resource handle.""" return ResourceHandle(self.key, self.variables) @@ -105,10 +121,10 @@ def get(self, key, default=None): """Get the value of a resource variable.""" return self.variables.get(key, default) - def __str__(self): + def __str__(self) -> str: return "%s%r" % (self.key, self.variables) - def __repr__(self): + def __repr__(self) -> str: return "%s(%r)" % (self.__class__.__name__, self.variables) def __hash__(self): @@ -139,7 +155,7 @@ class ResourceHandle(object): A handle uniquely identifies a resource. A handle can be stored and used with a `ResourcePool` to retrieve the same resource at a later date. """ - def __init__(self, key, variables=None): + def __init__(self, key: str, variables=None): self.key = key self.variables = variables or {} @@ -154,7 +170,7 @@ def to_dict(self): return dict(key=self.key, variables=self.variables) @classmethod - def from_dict(cls, d): + def from_dict(cls, d) -> Self: """Return a `ResourceHandle` instance from a serialized dict This should ONLY be used with dicts created with ResourceHandle.to_dict; @@ -169,10 +185,10 @@ def _hashable_repr(self): tuple(sorted(self.variables.items())) ) - def __str__(self): + def __str__(self) -> str: return str(self.to_dict()) - def __repr__(self): + def __repr__(self) -> str: return "%s(%r, %r)" % (self.__class__.__name__, self.key, self.variables) def __eq__(self, other): @@ -194,11 +210,11 @@ class ResourcePool(object): existence of the resource before creating one from a pool. """ def __init__(self, cache_size=None): - self.resource_classes = {} + self.resource_classes: dict[str, type[Resource]] = {} cache = lru_cache(maxsize=cache_size) self.cached_get_resource = cache(self._get_resource) - def register_resource(self, resource_class): + def register_resource(self, resource_class: type[Resource]) -> None: resource_key = resource_class.key assert issubclass(resource_class, Resource) assert resource_key is not None @@ -216,20 +232,20 @@ def register_resource(self, resource_class): self.resource_classes[resource_key] = resource_class - def get_resource_from_handle(self, resource_handle): + def get_resource_from_handle(self, resource_handle: ResourceHandle) -> Resource: return self.cached_get_resource(resource_handle) - def clear_caches(self): + def clear_caches(self) -> None: self.cached_get_resource.cache_clear() - def get_resource_class(self, resource_key): + def get_resource_class(self, resource_key) -> type[Resource]: resource_class = self.resource_classes.get(resource_key) if resource_class is None: raise ResourceError("Error getting resource from pool: Unknown " "resource type %r" % resource_key) return resource_class - def _get_resource(self, resource_handle): + def _get_resource(self, resource_handle: ResourceHandle) -> Resource: resource_class = self.get_resource_class(resource_handle.key) return resource_class(resource_handle.variables) @@ -254,15 +270,15 @@ class ResourceWrapper(object, metaclass=AttributeForwardMeta): """ keys = None - def __init__(self, resource): + def __init__(self, resource: Resource): self.wrapped = resource @property - def resource(self): + def resource(self) -> Resource: return self.wrapped @property - def handle(self): + def handle(self) -> ResourceHandle: return self.resource.handle @property diff --git a/src/rez/utils/schema.py b/src/rez/utils/schema.py index 36c22380e..1ecefd226 100644 --- a/src/rez/utils/schema.py +++ b/src/rez/utils/schema.py @@ -6,6 +6,7 @@ Utilities for working with dict-based schemas. """ from rez.vendor.schema.schema import Schema, Optional, Use, And +from rez.config import Validatable # an alias which just so happens to be the same number of characters as @@ -68,7 +69,7 @@ def _to(value): d[k] = _to(v) if allow_custom_keys: d[Optional(str)] = modifier or object - schema = Schema(d) + schema: Validatable = Schema(d) elif modifier: schema = And(value, modifier) else: diff --git a/src/rez/utils/sourcecode.py b/src/rez/utils/sourcecode.py index ff835e6ac..149d9bc2f 100644 --- a/src/rez/utils/sourcecode.py +++ b/src/rez/utils/sourcecode.py @@ -2,12 +2,15 @@ # Copyright Contributors to the Rez Project +from __future__ import annotations + from rez.utils.formatting import indent from rez.utils.data_utils import cached_property from rez.utils.logging_ import print_debug from rez.util import load_module_from_file from inspect import getsourcelines from textwrap import dedent +from types import FunctionType, MethodType from glob import glob import traceback import os.path @@ -93,8 +96,8 @@ class SourceCode(object): This object is aware of the decorators defined in this sourcefile (such as 'include') and deals with them appropriately. """ - def __init__(self, source=None, func=None, filepath=None, - eval_as_function=True): + def __init__(self, source: str | None = None, func: FunctionType | MethodType | None = None, + filepath=None, eval_as_function=True): self.source = (source or '').rstrip() self.func = func self.filepath = filepath diff --git a/src/rez/utils/typing.py b/src/rez/utils/typing.py new file mode 100644 index 000000000..61e2b2ba6 --- /dev/null +++ b/src/rez/utils/typing.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright Contributors to the Rez Project + + +from __future__ import absolute_import, print_function + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + # FIXME: use typing.Protocol instead of this workaround when python 3.7 support is dropped + from typing import Protocol + +else: + class Protocol(object): + pass + + +class SupportsLessThan(Protocol): + def __lt__(self, __other: Any) -> bool: + pass diff --git a/src/rez/version/_requirement.py b/src/rez/version/_requirement.py index 9e72a1133..cb9f4aa62 100644 --- a/src/rez/version/_requirement.py +++ b/src/rez/version/_requirement.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright Contributors to the Rez Project - +from __future__ import annotations from rez.version._version import Version, VersionRange from rez.version._util import _Common import re +from typing import Iterator class VersionedObject(_Common): @@ -20,15 +21,18 @@ class VersionedObject(_Common): sep_regex_str = r'[-@#]' sep_regex = re.compile(sep_regex_str) - def __init__(self, s): + def __init__(self, s: str): """ Args: s (str): """ - self.name_ = None - self.version_ = None + self.name_: str + self.version_: Version self.sep_ = '-' + if s is None: + # this is a special case in VersionedObject.construct, but name and version_ + # are always set. return m = self.sep_regex.search(s) @@ -43,20 +47,20 @@ def __init__(self, s): self.version_ = Version() @classmethod - def construct(cls, name, version=None): + def construct(cls, name: str, version: Version | None = None) -> VersionedObject: """Create a VersionedObject directly from an object name and version. Args: name (str): Object name string. version (typing.Optional[Version]): Version object. """ - other = VersionedObject(None) + other = VersionedObject(None) # type: ignore[arg-type] # special case other.name_ = name other.version_ = Version() if version is None else version return other @property - def name(self): + def name(self) -> str: """Name of the object. Returns: @@ -65,7 +69,7 @@ def name(self): return self.name_ @property - def version(self): + def version(self) -> Version: """Version of the object. Returns: @@ -73,7 +77,7 @@ def version(self): """ return self.version_ - def as_exact_requirement(self): + def as_exact_requirement(self) -> str: """Get the versioned object, as an exact requirement string. Returns: @@ -86,15 +90,15 @@ def as_exact_requirement(self): ver_str = str(self.version_) return self.name_ + sep_str + ver_str - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return (isinstance(other, VersionedObject) and (self.name_ == other.name_) and (self.version_ == other.version_)) - def __hash__(self): + def __hash__(self) -> int: return hash((self.name_, self.version_)) - def __str__(self): + def __str__(self) -> str: sep_str = '' ver_str = '' if self.version_: @@ -137,18 +141,20 @@ class Requirement(_Common): """ sep_regex = re.compile(r'[-@#=<>]') - def __init__(self, s, invalid_bound_error=True): + def __init__(self, s: str | None, invalid_bound_error: bool = True): """ Args: s (str): Requirement string invalid_bound_error (bool): If True, raise :exc:`VersionError` if an impossible range is given, such as ``3+<2``. """ - self.name_ = None - self.range_ = None + # there are two constructors where Requirement(None) is called, but they + # both set self.name, so we do not set its value to None here. + self.name_: str + self.range_: VersionRange | None = None self.negate_ = False self.conflict_ = False - self._str = None + self._str: str | None = None self.sep_ = '-' if s is None: return @@ -183,7 +189,7 @@ def __init__(self, s, invalid_bound_error=True): self.range_ = VersionRange() @classmethod - def construct(cls, name, range=None): + def construct(cls, name: str, range: VersionRange | None = None) -> Requirement: """Create a requirement directly from an object name and VersionRange. Args: @@ -197,7 +203,7 @@ def construct(cls, name, range=None): return other @property - def name(self): + def name(self) -> str: """Name of the required object. Returns: @@ -206,7 +212,7 @@ def name(self): return self.name_ @property - def range(self): + def range(self) -> VersionRange: """Version range of the requirement. Returns: @@ -215,7 +221,7 @@ def range(self): return self.range_ @property - def conflict(self): + def conflict(self) -> bool: """True if the requirement is a conflict requirement, eg "!foo", "~foo-1". Returns: @@ -224,7 +230,7 @@ def conflict(self): return self.conflict_ @property - def weak(self): + def weak(self) -> bool: """True if the requirement is weak, eg "~foo". .. note:: @@ -236,7 +242,7 @@ def weak(self): """ return self.negate_ - def safe_str(self): + def safe_str(self) -> str: """Return a string representation that is safe for the current filesystem, and guarantees that no two different Requirement objects will encode to the same value. @@ -246,7 +252,7 @@ def safe_str(self): """ return str(self) - def conflicts_with(self, other): + def conflicts_with(self, other: object) -> bool: """Returns True if this requirement conflicts with another :class:`Requirement` or :class:`VersionedObject`. @@ -264,15 +270,17 @@ def conflicts_with(self, other): return other.range_.issuperset(self.range_) else: return not self.range_.intersects(other.range_) - else: # VersionedObject + elif isinstance(other, VersionedObject): if (self.name_ != other.name_) or (self.range is None): return False if self.conflict: return (other.version_ in self.range_) else: return (other.version_ not in self.range_) + else: + return NotImplemented - def merged(self, other): + def merged(self, other: Requirement) -> Requirement | None: """Merge two requirements. Two requirements can be in conflict and if so, this function returns @@ -335,23 +343,24 @@ def _r(r_): r.range_ = range_ return r - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return (isinstance(other, Requirement) and (self.name_ == other.name_) and (self.range_ == other.range_) and (self.conflict_ == other.conflict_)) - def __hash__(self): + def __hash__(self) -> int: return hash(str(self)) - def __str__(self): + def __str__(self) -> str: if self._str is None: pre_str = '~' if self.negate_ else ('!' if self.conflict_ else '') range_str = '' sep_str = '' range_ = self.range_ - if self.negate_: + # Note: the only time that range_ is None is if self.negate_ is True + if self.negate_ or range_ is None: range_ = ~range_ if range_ else VersionRange() if not range_.is_any(): @@ -370,16 +379,16 @@ class RequirementList(_Common): optimal form, merging any requirements for common objects. Order of objects is retained. """ - def __init__(self, requirements): + def __init__(self, requirements: list[Requirement]): """ Args: requirements (list[Requirement]): List of requirements. """ - self.requirements_ = [] - self.conflict_ = None - self.requirements_dict = {} - self.names_ = set() - self.conflict_names_ = set() + self.requirements_: list[Requirement] = [] + self.conflict_: tuple[Requirement, Requirement] | None = None + self.requirements_dict: dict[str, Requirement] = {} + self.names_: set[str] = set() + self.conflict_names_: set[str] = set() for req in requirements: existing_req = self.requirements_dict.get(req.name) @@ -410,7 +419,7 @@ def __init__(self, requirements): self.names_.add(req.name) @property - def requirements(self): + def requirements(self) -> list[Requirement]: """Returns optimised list of requirements, or None if there are conflicts. @@ -420,7 +429,7 @@ def requirements(self): return self.requirements_ @property - def conflict(self): + def conflict(self) -> tuple[Requirement, Requirement] | None: """Get the requirement conflict, if any. Returns: @@ -430,7 +439,7 @@ def conflict(self): return self.conflict_ @property - def names(self): + def names(self) -> set[str]: """Set of names of requirements, not including conflict requirements. Returns: @@ -439,7 +448,7 @@ def names(self): return self.names_ @property - def conflict_names(self): + def conflict_names(self) -> set[str]: """Set of conflict requirement names. Returns: @@ -447,11 +456,11 @@ def conflict_names(self): """ return self.conflict_names_ - def __iter__(self): + def __iter__(self) -> Iterator[Requirement]: for requirement in (self.requirements_ or []): yield requirement - def get(self, name): + def get(self, name: str) -> Requirement | None: """Returns the requirement for the given object, or None. Args: @@ -462,12 +471,12 @@ def get(self, name): """ return self.requirements_dict.get(name) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return (isinstance(other, RequirementList) and (self.requirements_ == other.requirements_) and (self.conflict_ == other.conflict_)) - def __str__(self): + def __str__(self) -> str: if self.conflict_: s1 = str(self.conflict_[0]) s2 = str(self.conflict_[1]) diff --git a/src/rez/version/_version.py b/src/rez/version/_version.py index 42c2caa4a..fee40bab7 100644 --- a/src/rez/version/_version.py +++ b/src/rez/version/_version.py @@ -2,25 +2,33 @@ # Copyright Contributors to the Rez Project +from __future__ import annotations + from rez.version._util import VersionError, ParseException, _Common, \ dedup from bisect import bisect_left import copy import string import re +from typing import cast, Callable, Generic, Iterable, TypeVar, TYPE_CHECKING + +if TYPE_CHECKING: + from typing_extensions import Self + +T = TypeVar("T") re_token = re.compile(r"[a-zA-Z0-9_]+") class _Comparable(_Common): - def __gt__(self, other): + def __gt__(self, other: object) -> bool: return not (self < other or self == other) - def __le__(self, other): + def __le__(self, other: object) -> bool: return self < other or self == other - def __ge__(self, other): + def __ge__(self, other: object) -> bool: return not self < other @@ -28,25 +36,29 @@ class _ReversedComparable(_Common): def __init__(self, value): self.value = value - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + if not isinstance(other, _ReversedComparable): + return NotImplemented return self.value == other.value - def __lt__(self, other): + def __lt__(self, other: object) -> bool: + if not isinstance(other, _ReversedComparable): + return NotImplemented return self.value > other.value - def __gt__(self, other): + def __gt__(self, other: object) -> bool: return not (self < other or self == other) - def __le__(self, other): + def __le__(self, other: object) -> bool: return self < other or self == other - def __ge__(self, other): + def __ge__(self, other: object) -> bool: return not self < other - def __str__(self): + def __str__(self) -> str: return f"reverse({self.value!r})" - def __repr__(self): + def __repr__(self) -> str: return "reverse(%r)" % self.value @@ -60,7 +72,7 @@ class VersionToken(_Comparable): Version tokens are only allowed to contain alphanumerics (any case) and underscores. """ - def __init__(self, token): + def __init__(self, token: str): """ Args: token (str): Token string, eg "rc02" @@ -68,14 +80,14 @@ def __init__(self, token): raise NotImplementedError @classmethod - def create_random_token_string(cls): + def create_random_token_string(cls) -> str: """Create a random token string. For testing purposes only. :meta private: """ raise NotImplementedError - def less_than(self, other): + def less_than(self, other: VersionToken) -> bool: """Compare to another :class:`VersionToken`. Args: @@ -90,7 +102,7 @@ def next(self): """Returns the next largest token.""" raise NotImplementedError - def __str__(self): + def __str__(self) -> str: raise NotImplementedError def __lt__(self, other): @@ -105,26 +117,26 @@ class NumericToken(VersionToken): Version token supporting numbers only. Padding is ignored. """ - def __init__(self, token): + def __init__(self, token: str): if not token.isdigit(): raise VersionError("Invalid version token: '%s'" % token) else: self.n = int(token) @classmethod - def create_random_token_string(cls): + def create_random_token_string(cls) -> str: import random chars = string.digits return ''.join([chars[random.randint(0, len(chars) - 1)] for _ in range(8)]) - def __str__(self): + def __str__(self) -> str: return str(self.n) def __eq__(self, other): return (self.n == other.n) - def less_than(self, other): + def less_than(self, other: NumericToken) -> bool: return (self.n < other.n) def __next__(self): @@ -184,22 +196,23 @@ class AlphanumericVersionToken(VersionToken): numeric_regex = re.compile("[0-9]+") regex = re.compile(r"[a-zA-Z0-9_]+\Z") - def __init__(self, token): + def __init__(self, token: str): if token is None: - self.subtokens = None + # this is a special case used in __next__, and subtokens is always set there + pass elif not self.regex.match(token): raise VersionError("Invalid version token: '%s'" % token) else: self.subtokens = self._parse(token) @classmethod - def create_random_token_string(cls): + def create_random_token_string(cls) -> str: import random chars = string.digits + string.ascii_letters return ''.join([chars[random.randint(0, len(chars) - 1)] for _ in range(8)]) - def __str__(self): + def __str__(self) -> str: return ''.join(map(str, self.subtokens)) def __eq__(self, other): @@ -222,7 +235,7 @@ def next(self): return self.__next__() @classmethod - def _parse(cls, s): + def _parse(cls, s: str) -> list[_SubToken]: subtokens = [] alphas = cls.numeric_regex.split(s) numerics = cls.numeric_regex.findall(s) @@ -272,19 +285,19 @@ class Version(_Comparable): The empty version ``''`` is the smallest possible version, and can be used to represent an unversioned resource. """ - inf = None + inf: Version - def __init__(self, ver_str='', make_token=AlphanumericVersionToken): + def __init__(self, ver_str: str | None = '', make_token=AlphanumericVersionToken): """ Args: ver_str (str): Version string. make_token (typing.Callable[[str], None]): Callable that creates a VersionToken subclass from a string. """ - self.tokens = [] + self.tokens: list[VersionToken] | None = [] self.seps = [] - self._str = None - self._hash = None + self._str: str | None = None + self._hash: int | None = None if ver_str: toks = re_token.findall(ver_str) @@ -304,7 +317,7 @@ def __init__(self, ver_str='', make_token=AlphanumericVersionToken): self.seps = seps[1:-1] - def copy(self): + def copy(self) -> Version: """ Returns a copy of the version. @@ -312,11 +325,13 @@ def copy(self): Version: """ other = Version(None) + if self.tokens is None: + raise RuntimeError("Version.inf cannot be copied") other.tokens = self.tokens[:] other.seps = self.seps[:] return other - def trim(self, len_): + def trim(self, len_: int) -> Version: """Return a copy of the version, possibly with less tokens. Args: @@ -327,11 +342,13 @@ def trim(self, len_): Version: """ other = Version(None) + if self.tokens is None: + raise RuntimeError("Version.inf cannot be trimmed") other.tokens = self.tokens[:len_] other.seps = self.seps[:len_ - 1] return other - def __next__(self): + def __next__(self) -> Version: """Return :meth:`next` version. Eg, ``next(1.2)`` is ``1.2_``""" if self.tokens: other = self.copy() @@ -341,11 +358,11 @@ def __next__(self): else: return Version.inf - def next(self): + def next(self) -> Version: return self.__next__() @property - def major(self): + def major(self) -> VersionToken: """Semantic versioning major version. Returns: @@ -354,7 +371,7 @@ def major(self): return self[0] @property - def minor(self): + def minor(self) -> VersionToken: """Semantic versioning minor version. Returns: @@ -363,7 +380,7 @@ def minor(self): return self[1] @property - def patch(self): + def patch(self) -> VersionToken: """Semantic versioning patch version. Returns: @@ -371,7 +388,7 @@ def patch(self): """ return self[2] - def as_tuple(self): + def as_tuple(self) -> tuple[str, ...]: """Convert to a tuple of strings. Example: @@ -382,25 +399,31 @@ def as_tuple(self): Returns: tuple[str]: """ + if self.tokens is None: + # Version.inf + return () return tuple(map(str, self.tokens)) - def __len__(self): + def __len__(self) -> int: return len(self.tokens or []) - def __getitem__(self, index): + def __getitem__(self, index: int) -> VersionToken: try: return (self.tokens or [])[index] except IndexError: raise IndexError("version token index out of range") - def __bool__(self): + def __bool__(self) -> bool: """The empty version equates to False.""" return bool(self.tokens) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return isinstance(other, Version) and self.tokens == other.tokens - def __lt__(self, other): + def __lt__(self, other: object) -> bool: + if not isinstance(other, Version): + return NotImplemented + if self.tokens is None: return False elif other.tokens is None: @@ -408,13 +431,13 @@ def __lt__(self, other): else: return (self.tokens < other.tokens) - def __hash__(self): + def __hash__(self) -> int: if self._hash is None: self._hash = hash(None) if self.tokens is None \ else hash(tuple(map(str, self.tokens))) return self._hash - def __str__(self): + def __str__(self) -> str: if self._str is None: self._str = "[INF]" if self.tokens is None \ else ''.join(str(x) + y for x, y in zip(self.tokens, self.seps + [''])) @@ -427,13 +450,13 @@ def __str__(self): class _LowerBound(_Comparable): - min = None + min: _LowerBound - def __init__(self, version, inclusive): + def __init__(self, version: Version, inclusive: bool): self.version = version self.inclusive = inclusive - def __str__(self): + def __str__(self) -> str: if self.version: s = "%s+" if self.inclusive else ">%s" return s % self.version @@ -449,10 +472,10 @@ def __lt__(self, other): or ((self.version == other.version) and (self.inclusive and not other.inclusive)) - def __hash__(self): + def __hash__(self) -> int: return hash((self.version, self.inclusive)) - def contains_version(self, version): + def contains_version(self, version: Version) -> bool: return (version > self.version) \ or (self.inclusive and (version == self.version)) @@ -461,15 +484,15 @@ def contains_version(self, version): class _UpperBound(_Comparable): - inf = None + inf: _UpperBound - def __init__(self, version, inclusive): + def __init__(self, version: Version, inclusive: bool): self.version = version self.inclusive = inclusive if not version and not inclusive: raise VersionError("Invalid upper bound: '%s'" % str(self)) - def __str__(self): + def __str__(self) -> str: s = "<=%s" if self.inclusive else "<%s" return s % self.version @@ -482,10 +505,10 @@ def __lt__(self, other): or ((self.version == other.version) and (not self.inclusive and other.inclusive)) - def __hash__(self): + def __hash__(self) -> int: return hash((self.version, self.inclusive)) - def contains_version(self, version): + def contains_version(self, version: Version) -> bool: return (version < self.version) \ or (self.inclusive and (version == self.version)) @@ -494,9 +517,11 @@ def contains_version(self, version): class _Bound(_Comparable): - any = None + any: _Bound - def __init__(self, lower=None, upper=None, invalid_bound_error=True): + def __init__(self, lower: _LowerBound | None = None, + upper: _UpperBound | None = None, + invalid_bound_error: bool = True): self.lower = lower or _LowerBound.min self.upper = upper or _UpperBound.inf @@ -509,7 +534,7 @@ def __init__(self, lower=None, upper=None, invalid_bound_error=True): ): raise VersionError("Invalid bound") - def __str__(self): + def __str__(self) -> str: if self.upper.version == Version.inf: return str(self.lower) elif self.lower.version == self.upper.version: @@ -534,26 +559,26 @@ def __lt__(self, other): def __hash__(self): return hash((self.lower, self.upper)) - def lower_bounded(self): + def lower_bounded(self) -> bool: return (self.lower != _LowerBound.min) - def upper_bounded(self): + def upper_bounded(self) -> bool: return (self.upper != _UpperBound.inf) - def contains_version(self, version): + def contains_version(self, version: Version) -> bool: return (self.version_containment(version) == 0) - def version_containment(self, version): + def version_containment(self, version: Version) -> int: if not self.lower.contains_version(version): return -1 if not self.upper.contains_version(version): return 1 return 0 - def contains_bound(self, bound): + def contains_bound(self, bound: _Bound) -> bool: return (self.lower <= bound.lower) and (self.upper >= bound.upper) - def intersects(self, other): + def intersects(self, other: _Bound) -> bool: lower = max(self.lower, other.lower) upper = min(self.upper, other.upper) @@ -561,7 +586,7 @@ def intersects(self, other): (lower.version == upper.version) and (lower.inclusive and upper.inclusive) ) - def intersection(self, other): + def intersection(self, other: _Bound) -> _Bound | None: lower = max(self.lower, other.lower) upper = min(self.upper, other.upper) @@ -576,6 +601,20 @@ def intersection(self, other): _Bound.any = _Bound() +def action(fn): + def fn_(self): + result = fn(self) + if self.debug: + label = fn.__name__.replace("_act_", "") + print("%-21s: %s" % (label, self._input_string)) + for key, value in self._groups.items(): + print(" %-17s= %s" % (key, value)) + print(" %-17s= %s" % ("bounds", self.bounds)) + return result + + return fn_ + + class _VersionRangeParser(object): debug = False # set to True to enable parser debugging @@ -659,7 +698,7 @@ class _VersionRangeParser(object): regex = re.compile(version_range_regex, re_flags) - def __init__(self, input_string, make_token, invalid_bound_error=True): + def __init__(self, input_string: str, make_token, invalid_bound_error=True): self.make_token = make_token self._groups = {} self._input_string = input_string @@ -712,29 +751,17 @@ def __init__(self, input_string, make_token, invalid_bound_error=True): elif self._groups['upper_bound']: self._act_upper_bound() - def _is_lower_bound_exclusive(self, token): + def _is_lower_bound_exclusive(self, token: str) -> bool: return (token == ">") - def _is_upper_bound_exclusive(self, token): + def _is_upper_bound_exclusive(self, token: str) -> bool: return (token == "<") - def _create_version_from_token(self, token): + def _create_version_from_token(self, token: str) -> Version: return Version(token, make_token=self.make_token) - def action(fn): - def fn_(self): - result = fn(self) - if self.debug: - label = fn.__name__.replace("_act_", "") - print("%-21s: %s" % (label, self._input_string)) - for key, value in self._groups.items(): - print(" %-17s= %s" % (key, value)) - print(" %-17s= %s" % ("bounds", self.bounds)) - return result - return fn_ - @action - def _act_version(self): + def _act_version(self) -> None: version = self._create_version_from_token(self._groups['version']) lower_bound = _LowerBound(version, True) upper_bound = _UpperBound(version.next(), False) if version else None @@ -742,7 +769,7 @@ def _act_version(self): self.bounds.append(_Bound(lower_bound, upper_bound)) @action - def _act_exact_version(self): + def _act_exact_version(self) -> None: version = self._create_version_from_token(self._groups['exact_version_group']) lower_bound = _LowerBound(version, True) upper_bound = _UpperBound(version, True) @@ -750,7 +777,7 @@ def _act_exact_version(self): self.bounds.append(_Bound(lower_bound, upper_bound)) @action - def _act_bound(self): + def _act_bound(self) -> None: lower_version = self._create_version_from_token(self._groups['inclusive_lower_version']) lower_bound = _LowerBound(lower_version, True) @@ -760,7 +787,7 @@ def _act_bound(self): self.bounds.append(_Bound(lower_bound, upper_bound, self.invalid_bound_error)) @action - def _act_lower_bound(self): + def _act_lower_bound(self) -> None: version = self._create_version_from_token(self._groups['lower_version']) exclusive = self._is_lower_bound_exclusive(self._groups['lower_bound_prefix']) lower_bound = _LowerBound(version, not exclusive) @@ -768,7 +795,7 @@ def _act_lower_bound(self): self.bounds.append(_Bound(lower_bound, None)) @action - def _act_upper_bound(self): + def _act_upper_bound(self) -> None: version = self._create_version_from_token(self._groups['upper_version']) exclusive = self._is_upper_bound_exclusive(self._groups['upper_bound_prefix']) upper_bound = _UpperBound(version, not exclusive) @@ -776,7 +803,7 @@ def _act_upper_bound(self): self.bounds.append(_Bound(None, upper_bound)) @action - def _act_lower_and_upper_bound_asc(self): + def _act_lower_and_upper_bound_asc(self) -> None: lower_bound = None upper_bound = None @@ -793,7 +820,7 @@ def _act_lower_and_upper_bound_asc(self): self.bounds.append(_Bound(lower_bound, upper_bound, self.invalid_bound_error)) @action - def _act_lower_and_upper_bound_desc(self): + def _act_lower_and_upper_bound_desc(self) -> None: lower_bound = None upper_bound = None @@ -867,8 +894,9 @@ class VersionRange(_Comparable): valid version range syntax. For example, ``>`` is a valid range - read like ``>''``, it means ``any version greater than the empty version``. """ - def __init__(self, range_str='', make_token=AlphanumericVersionToken, - invalid_bound_error=True): + def __init__(self, range_str: str | None = '', + make_token: type[VersionToken] = AlphanumericVersionToken, + invalid_bound_error: bool = True): """ Args: range_str (str): Range string, such as "3", "3+<4.5", "2|6+". The range @@ -878,7 +906,7 @@ def __init__(self, range_str='', make_token=AlphanumericVersionToken, invalid_bound_error (bool): If True, raise an exception if an impossible range is given, such as '3+<2'. """ - self._str = None + self._str: str | None = None self.bounds = [] # note: kept in ascending order if range_str is None: return @@ -899,7 +927,7 @@ def __init__(self, range_str='', make_token=AlphanumericVersionToken, else: self.bounds.append(_Bound.any) - def is_any(self): + def is_any(self) -> bool: """ Returns: bool: True if this is the "any" range, ie the empty string range @@ -907,7 +935,7 @@ def is_any(self): """ return (len(self.bounds) == 1) and (self.bounds[0] == _Bound.any) - def lower_bounded(self): + def lower_bounded(self) -> bool: """ Returns: bool: True if the range has a lower bound (that is not the empty @@ -915,35 +943,35 @@ def lower_bounded(self): """ return self.bounds[0].lower_bounded() - def upper_bounded(self): + def upper_bounded(self) -> bool: """ Returns: bool: True if the range has an upper bound. """ return self.bounds[-1].upper_bounded() - def bounded(self): + def bounded(self) -> bool: """ Returns: bool: True if the range has a lower and upper bound. """ return (self.lower_bounded() and self.upper_bounded()) - def issuperset(self, range): + def issuperset(self, range) -> bool: """ Returns: bool: True if the VersionRange is contained within this range. """ return self._issuperset(self.bounds, range.bounds) - def issubset(self, range): + def issubset(self, range) -> bool: """ Returns: bool: True if we are contained within the version range. """ return range.issuperset(self) - def union(self, other): + def union(self, other: VersionRange | Iterable[VersionRange]) -> VersionRange: """OR together version ranges. Calculates the union of this range with one or more other ranges. @@ -965,7 +993,7 @@ def union(self, other): range.bounds = bounds return range - def intersection(self, other): + def intersection(self, other: VersionRange | Iterable[VersionRange]) -> VersionRange | None: """AND together version ranges. Calculates the intersection of this range with one or more other ranges. @@ -990,7 +1018,7 @@ def intersection(self, other): range.bounds = bounds return range - def inverse(self): + def inverse(self) -> VersionRange | None: """Calculate the inverse of the range. Returns: @@ -1005,7 +1033,7 @@ def inverse(self): range.bounds = bounds return range - def intersects(self, other): + def intersects(self, other: VersionRange) -> bool: """Determine if we intersect with another range. Args: @@ -1016,7 +1044,7 @@ def intersects(self, other): """ return self._intersects(self.bounds, other.bounds) - def split(self): + def split(self) -> list[VersionRange]: """Split into separate contiguous ranges. Returns: @@ -1031,8 +1059,11 @@ def split(self): return ranges @classmethod - def as_span(cls, lower_version=None, upper_version=None, - lower_inclusive=True, upper_inclusive=True): + def as_span(cls, + lower_version: Version | None = None, + upper_version: Version | None = None, + lower_inclusive=True, + upper_inclusive=True): """Create a range from lower_version..upper_version. Args: @@ -1054,7 +1085,7 @@ def as_span(cls, lower_version=None, upper_version=None, return range @classmethod - def from_version(cls, version, op=None): + def from_version(cls, version: Version, op: str | None = None) -> Self: """Create a range from a version. Args: @@ -1093,7 +1124,7 @@ def from_version(cls, version, op=None): return range @classmethod - def from_versions(cls, versions): + def from_versions(cls, versions: Iterable[Version]) -> VersionRange: """Create a range from a list of versions. This method creates a range that contains only the given versions and @@ -1114,7 +1145,7 @@ def from_versions(cls, versions): range.bounds.append(bound) return range - def to_versions(self): + def to_versions(self) -> list[Version] | None: """Returns exact version ranges as Version objects, or None if there are no exact version ranges present. @@ -1129,7 +1160,7 @@ def to_versions(self): return versions or None - def contains_version(self, version): + def contains_version(self, version: Version) -> bool: """Returns True if version is contained in this range. Returns: @@ -1149,7 +1180,9 @@ def contains_version(self, version): return False - def iter_intersect_test(self, iterable, key=None, descending=False): + def iter_intersect_test(self, iterable: Iterable[T], + key: Callable[[T], Version] | None = None, + descending: bool = False) -> _ContainsVersionIterator[T]: """Performs containment tests on a sorted list of versions. This is more optimal than performing separate containment tests on a @@ -1170,7 +1203,9 @@ def iter_intersect_test(self, iterable, key=None, descending=False): """ return _ContainsVersionIterator(self, iterable, key, descending) - def iter_intersecting(self, iterable, key=None, descending=False): + def iter_intersecting(self, iterable: Iterable[T], + key: Callable[[T], Version] | None = None, + descending: bool = False) -> _ContainsVersionIterator[T]: """Like :meth:iter_intersect_test`, but returns intersections only. Returns: @@ -1180,7 +1215,9 @@ def iter_intersecting(self, iterable, key=None, descending=False): self, iterable, key, descending, mode=_ContainsVersionIterator.MODE_INTERSECTING ) - def iter_non_intersecting(self, iterable, key=None, descending=False): + def iter_non_intersecting(self, iterable: Iterable[T], + key: Callable[[T], Version] | None = None, + descending: bool = False) -> _ContainsVersionIterator[T]: """Like :meth:`iter_intersect_test`, but returns non-intersections only. Returns: @@ -1190,7 +1227,7 @@ def iter_non_intersecting(self, iterable, key=None, descending=False): self, iterable, key, descending, mode=_ContainsVersionIterator.MODE_NON_INTERSECTING ) - def span(self): + def span(self) -> VersionRange: """Return a contiguous range that is a superset of this range. Returns: @@ -1236,32 +1273,32 @@ def visit_versions(self, func): if isinstance(result, Version): bound.upper.version = result - def __contains__(self, version_or_range): + def __contains__(self, version_or_range: Version | VersionRange) -> bool: if isinstance(version_or_range, Version): return self.contains_version(version_or_range) else: return self.issuperset(version_or_range) - def __len__(self): + def __len__(self) -> int: return len(self.bounds) - def __invert__(self): + def __invert__(self) -> VersionRange | None: return self.inverse() - def __and__(self, other): + def __and__(self, other) -> VersionRange | None: return self.intersection(other) - def __or__(self, other): + def __or__(self, other) -> VersionRange: return self.union(other) - def __add__(self, other): + def __add__(self, other) -> VersionRange: return self.union(other) - def __sub__(self, other): + def __sub__(self, other) -> VersionRange | None: inv = other.inverse() return None if inv is None else self.intersection(inv) - def __str__(self): + def __str__(self) -> str: if self._str is None: self._str = '|'.join(map(str, self.bounds)) return self._str @@ -1272,10 +1309,10 @@ def __eq__(self, other): def __lt__(self, other): return (self.bounds < other.bounds) - def __hash__(self): + def __hash__(self) -> int: return hash(tuple(self.bounds)) - def _contains_version(self, version): + def _contains_version(self, version: Version) -> tuple[int, bool]: vbound = _Bound(_LowerBound(version, True)) i = bisect_left(self.bounds, vbound) if i and self.bounds[i - 1].contains_version(version): @@ -1285,14 +1322,14 @@ def _contains_version(self, version): return i, False @classmethod - def _union(cls, bounds): + def _union(cls, bounds: list[_Bound]) -> list[_Bound]: if len(bounds) < 2: return bounds bounds_ = list(sorted(bounds)) new_bounds = [] - prev_bound = None - upper = None + prev_bound: _Bound | None = None + upper: _UpperBound | None = None start = 0 for i, bound in enumerate(bounds_): @@ -1312,7 +1349,7 @@ def _union(cls, bounds): return new_bounds @classmethod - def _intersection(cls, bounds1, bounds2): + def _intersection(cls, bounds1: list[_Bound], bounds2: list[_Bound]) -> list[_Bound]: new_bounds = [] for bound1 in bounds1: for bound2 in bounds2: @@ -1322,22 +1359,22 @@ def _intersection(cls, bounds1, bounds2): return new_bounds @classmethod - def _inverse(cls, bounds): - lbounds = [None] - ubounds = [] + def _inverse(cls, bounds: list[_Bound]) -> list[_Bound]: + lbounds: list[_LowerBound | None] = [None] + ubounds: list[_UpperBound | None] = [] for bound in bounds: if not bound.lower.version and bound.lower.inclusive: ubounds.append(None) else: - b = _UpperBound(bound.lower.version, not bound.lower.inclusive) - ubounds.append(b) + ub = _UpperBound(bound.lower.version, not bound.lower.inclusive) + ubounds.append(ub) if bound.upper.version == Version.inf: lbounds.append(None) else: - b = _LowerBound(bound.upper.version, not bound.upper.inclusive) - lbounds.append(b) + lb = _LowerBound(bound.upper.version, not bound.upper.inclusive) + lbounds.append(lb) ubounds.append(None) new_bounds = [] @@ -1349,7 +1386,7 @@ def _inverse(cls, bounds): return new_bounds @classmethod - def _issuperset(cls, bounds1, bounds2): + def _issuperset(cls, bounds1: list[_Bound], bounds2: list[_Bound]) -> bool: lo = 0 for bound2 in bounds2: i = bisect_left(bounds1, bound2, lo=lo) @@ -1364,7 +1401,7 @@ def _issuperset(cls, bounds1, bounds2): return True @classmethod - def _intersects(cls, bounds1, bounds2): + def _intersects(cls, bounds1: list[_Bound], bounds2: list[_Bound]) -> bool: # sort so bounds1 is the shorter list bounds1, bounds2 = sorted((bounds1, bounds2), key=lambda x: len(x)) @@ -1388,23 +1425,27 @@ def _intersects(cls, bounds1, bounds2): return False -class _ContainsVersionIterator(object): +class _ContainsVersionIterator(Generic[T]): MODE_INTERSECTING = 0 MODE_NON_INTERSECTING = 2 MODE_ALL = 3 - def __init__(self, range_, iterable, key=None, descending=False, mode=MODE_ALL): + def __init__(self, range_: VersionRange, iterable: Iterable[T], + key: Callable[[T], Version] | None = None, + descending: bool = False, mode=MODE_ALL): self.mode = mode self.range_ = range_ - self.index = None + self.index: int | None = None self.nbounds = len(self.range_.bounds) self._constant = True if range_.is_any() else None self.fn = self._descending if descending else self._ascending self.it = iter(iterable) if key is None: - key = lambda x: x # noqa: E731 + # FIXME: this case seems to assume that iterable is Iterable[Version] + key = cast(Callable[[T], Version], lambda x: x) # noqa: E731 self.keyfunc = key + self.next_fn: Callable[[], tuple[bool, T]] | Callable[[], T] if mode == self.MODE_ALL: self.next_fn = self._next elif mode == self.MODE_INTERSECTING: @@ -1412,16 +1453,16 @@ def __init__(self, range_, iterable, key=None, descending=False, mode=MODE_ALL): else: self.next_fn = self._next_non_intersecting - def __iter__(self): + def __iter__(self) -> _ContainsVersionIterator[T]: return self - def __next__(self): + def __next__(self) -> T | tuple[bool, T]: return self.next_fn() - def next(self): + def next(self) -> T | tuple[bool, T]: return self.next_fn() - def _next(self): + def _next(self) -> tuple[bool, T]: value = next(self.it) if self._constant is not None: return self._constant, value @@ -1430,7 +1471,7 @@ def _next(self): intersects = self.fn(version) return intersects, value - def _next_intersecting(self): + def _next_intersecting(self) -> T: while True: value = next(self.it) @@ -1444,7 +1485,7 @@ def _next_intersecting(self): if intersects: return value - def _next_non_intersecting(self): + def _next_non_intersecting(self) -> T: while True: value = next(self.it) @@ -1459,13 +1500,13 @@ def _next_non_intersecting(self): return value @property - def _bound(self): + def _bound(self) -> _Bound | None: if self.index < self.nbounds: return self.range_.bounds[self.index] else: return None - def _ascending(self, version): + def _ascending(self, version: Version) -> bool: if self.index is None: self.index, contains = self.range_._contains_version(version) bound = self._bound @@ -1501,7 +1542,7 @@ def _ascending(self, version): elif j == -1: return False - def _descending(self, version): + def _descending(self, version: Version) -> bool: if self.index is None: self.index, contains = self.range_._contains_version(version) bound = self._bound diff --git a/src/rezplugins/build_process/local.py b/src/rezplugins/build_process/local.py index 4fa1b8277..37c982500 100644 --- a/src/rezplugins/build_process/local.py +++ b/src/rezplugins/build_process/local.py @@ -5,6 +5,8 @@ """ Builds packages on local host """ +from __future__ import annotations + from rez.config import config from rez.package_repository import package_repository_manager from rez.build_process import BuildProcessHelper, BuildType @@ -21,11 +23,16 @@ from rez.package_test import PackageTestRunner, PackageTestResults from hashlib import sha1 +from typing import TYPE_CHECKING import json import shutil import os import os.path +if TYPE_CHECKING: + from rez.packages import Variant + from rez.build_system import BuildResult + class LocalBuildProcess(BuildProcessHelper): """The default build process. @@ -45,7 +52,11 @@ def __init__(self, *nargs, **kwargs): self.ran_test_names = set() self.all_test_results = PackageTestResults() - def build(self, install_path=None, clean=False, install=False, variants=None): + def build(self, + install_path: str | None = None, + clean: bool = False, + install: bool = False, + variants: list[int] | None = None) -> int: self._print_header("Building %s..." % self.package.qualified_name) # build variants @@ -71,7 +82,7 @@ def build(self, install_path=None, clean=False, install=False, variants=None): return num_visited - def release(self, release_message=None, variants=None): + def release(self, release_message=None, variants: list[int] | None = None): self._print_header("Releasing %s..." % self.package.qualified_name) # test that we're in a state to release @@ -130,8 +141,13 @@ def release(self, release_message=None, variants=None): return num_released - def _build_variant_base(self, variant, build_type, install_path=None, - clean=False, install=False, **kwargs): + def _build_variant_base(self, + variant: Variant, + build_type, + install_path: str | None = None, + clean=False, + install=False, + **kwargs) -> BuildResult: # create build/install paths install_path = install_path or self.package.config.local_packages_path package_install_path = self.get_package_install_path(install_path) @@ -284,7 +300,7 @@ def _build_variant_base(self, variant, build_type, install_path=None, return build_result - def _install_include_modules(self, install_path): + def _install_include_modules(self, install_path: str) -> None: # install 'include' sourcefiles, used by funcs decorated with @include if not self.package.includes: return @@ -317,8 +333,12 @@ def _rmtree(self, path): except Exception as e: print_warning("Failed to delete %s - %s", path, e) - def _build_variant(self, variant, install_path=None, clean=False, - install=False, **kwargs): + def _build_variant(self, + variant: Variant, + install_path: str | None = None, + clean: bool = False, + install: bool = False, + **kwargs) -> str | None: if variant.index is not None: self._print_header( "Building variant %s (%s)..." @@ -367,7 +387,7 @@ def cancel_variant_install(): return build_result.get("build_env_script") - def _release_variant(self, variant, release_message=None, **kwargs): + def _release_variant(self, variant: Variant, release_message=None, **kwargs): release_path = self.package.config.release_packages_path # test if variant has already been released diff --git a/src/rezplugins/build_system/cmake.py b/src/rezplugins/build_system/cmake.py index 6a76d92de..a55f28888 100644 --- a/src/rezplugins/build_system/cmake.py +++ b/src/rezplugins/build_system/cmake.py @@ -5,17 +5,20 @@ """ CMake-based build system """ -from rez.build_system import BuildSystem +from __future__ import annotations + +from rez.build_system import BuildSystem, BuildResult from rez.build_process import BuildType from rez.resolved_context import ResolvedContext from rez.exceptions import BuildSystemError from rez.utils.execution import create_forwarding_script -from rez.packages import get_developer_package +from rez.packages import get_developer_package, Variant from rez.utils.platform_ import platform_ from rez.config import config from rez.utils.which import which from rez.vendor.schema.schema import Or from rez.shells import create_shell +import argparse import functools import os.path import sys @@ -62,11 +65,11 @@ class CMakeBuildSystem(BuildSystem): } @classmethod - def name(cls): + def name(cls) -> str: return "cmake" @classmethod - def child_build_system(cls): + def child_build_system(cls) -> str: return "make" @classmethod @@ -74,7 +77,7 @@ def is_valid_root(cls, path, package=None): return os.path.isfile(os.path.join(path, "CMakeLists.txt")) @classmethod - def bind_cli(cls, parser, group): + def bind_cli(cls, parser: argparse.ArgumentParser, group: argparse._ArgumentGroup): settings = config.plugins.build_system.cmake group.add_argument("--bt", "--build-target", dest="build_target", type=str, choices=cls.build_targets, @@ -86,7 +89,7 @@ def bind_cli(cls, parser, group): default=settings.build_system, help="set the cmake build system (default: %(default)s).") - def __init__(self, working_dir, opts=None, package=None, write_build_scripts=False, + def __init__(self, working_dir: str, opts=None, package=None, write_build_scripts=False, verbose=False, build_args=[], child_build_args=[]): super(CMakeBuildSystem, self).__init__( working_dir, @@ -105,8 +108,13 @@ def __init__(self, working_dir, opts=None, package=None, write_build_scripts=Fal raise RezCMakeError("Generation of Xcode project only available " "on the OSX platform") - def build(self, context, variant, build_path, install_path, install=False, - build_type=BuildType.local): + def build(self, + context: ResolvedContext, + variant: Variant, + build_path: str, + install_path: str, + install: bool = False, + build_type=BuildType.local) -> BuildResult: def _pr(s): if self.verbose: print(s) @@ -175,7 +183,7 @@ def _pr(s): post_actions_callback=post_actions_callback ) - ret = {} + ret = BuildResult() if retcode: ret["success"] = False return ret diff --git a/src/rezplugins/build_system/custom.py b/src/rezplugins/build_system/custom.py index e1f29e765..284fbbe67 100644 --- a/src/rezplugins/build_system/custom.py +++ b/src/rezplugins/build_system/custom.py @@ -5,16 +5,19 @@ """ Package-defined build command """ +from __future__ import annotations + from shlex import quote +from typing import TYPE_CHECKING import functools import os.path import sys import os -from rez.build_system import BuildSystem +from rez.build_system import BuildSystem, BuildResult from rez.build_process import BuildType from rez.utils.execution import create_forwarding_script -from rez.packages import get_developer_package +from rez.packages import get_developer_package, Variant from rez.resolved_context import ResolvedContext from rez.shells import create_shell from rez.exceptions import PackageMetadataError @@ -22,6 +25,10 @@ from rez.utils.logging_ import print_warning from rez.config import config +if TYPE_CHECKING: + import argparse + from rez.rex import RexExecutor + class CustomBuildSystem(BuildSystem): """This build system runs the 'build_command' defined in a package.py. @@ -73,7 +80,7 @@ def __init__(self, working_dir, opts=None, package=None, write_build_scripts=Fal child_build_args=child_build_args) @classmethod - def bind_cli(cls, parser, group): + def bind_cli(cls, parser: argparse.ArgumentParser, group: argparse._ArgumentGroup): """ Uses a 'parse_build_args.py' file to add options, if found. """ @@ -97,15 +104,20 @@ def bind_cli(cls, parser, group): # store extra args onto parser so we can get to it in self.build() setattr(parser, "_rezbuild_extra_args", list(extra_args)) - def build(self, context, variant, build_path, install_path, install=False, - build_type=BuildType.local): + def build(self, + context: ResolvedContext, + variant: Variant, + build_path: str, + install_path: str, + install: bool = False, + build_type=BuildType.local) -> BuildResult: """Perform the build. Note that most of the func args aren't used here - that's because this info is already passed to the custom build command via environment variables. """ - ret = {} + ret = BuildResult() if self.write_build_scripts: # write out the script that places the user in a build env @@ -215,7 +227,7 @@ def _actions_callback(executor): return ret @classmethod - def _add_build_actions(cls, executor, context, package, variant, + def _add_build_actions(cls, executor: RexExecutor, context: ResolvedContext, package, variant, build_type, install, build_path, install_path=None): cls.add_standard_build_actions( executor=executor, diff --git a/src/rezplugins/package_repository/filesystem.py b/src/rezplugins/package_repository/filesystem.py index d0b793b9d..ebb8b8121 100644 --- a/src/rezplugins/package_repository/filesystem.py +++ b/src/rezplugins/package_repository/filesystem.py @@ -5,6 +5,8 @@ """ Filesystem-based package repository """ +from __future__ import annotations + from contextlib import contextmanager from functools import lru_cache import os.path @@ -36,6 +38,11 @@ from rez.vendor.schema.schema import Schema, Optional, And, Use, Or from rez.version import Version, VersionRange +from typing import Iterator, Iterable, TYPE_CHECKING + +if TYPE_CHECKING: + from rez.packages import Package, Variant, PackageRepositoryResourceWrapper + from rez.package_resources import PackageRepositoryResource, VariantResource debug_print = config.debug_printer("resources") @@ -88,14 +95,14 @@ class FileSystemPackageFamilyResource(PackageFamilyResource): key = "filesystem.family" repository_type = "filesystem" - def _uri(self): + def _uri(self) -> str: return self.path @cached_property - def path(self): + def path(self) -> str: return os.path.join(self.location, self.name) - def get_last_release_time(self): + def get_last_release_time(self) -> float: # this repository makes sure to update path mtime every time a # variant is added to the repository try: @@ -103,7 +110,7 @@ def get_last_release_time(self): except OSError: return 0 - def iter_packages(self): + def iter_packages(self) -> Iterator[FileSystemPackageResource]: # check for unversioned package if config.allow_unversioned_packages: filepath, _ = self._repository._get_file(self.path) @@ -137,11 +144,11 @@ class FileSystemPackageResource(PackageResourceHelper): repository_type = "filesystem" schema = package_pod_schema - def _uri(self): + def _uri(self) -> str: return self.filepath @cached_property - def parent(self): + def parent(self) -> FileSystemPackageFamilyResource: family = self._repository.get_resource( FileSystemPackageFamilyResource.key, location=self.location, @@ -149,13 +156,13 @@ def parent(self): return family @cached_property - def state_handle(self): + def state_handle(self) -> float | None: if self.filepath: return os.path.getmtime(self.filepath) return None @property - def base(self): + def base(self) -> str | None: # Note: '_redirected_base' is a special attribute set by the build # process in order to perform pre-install/release package testing. See # `LocalBuildProcess._run_tests()` @@ -165,7 +172,7 @@ def base(self): return redirected_base or self.path @cached_property - def path(self): + def path(self) -> str: path = os.path.join(self.location, self.name) ver_str = self.get("version") if ver_str: @@ -173,7 +180,7 @@ def path(self): return path @cached_property - def filepath(self): + def filepath(self) -> str: return self._filepath_and_format[0] @cached_property @@ -268,7 +275,7 @@ class FileSystemVariantResource(VariantResourceHelper): repository_type = "filesystem" @cached_property - def parent(self): + def parent(self) -> FileSystemPackageResource: package = self._repository.get_resource( FileSystemPackageResource.key, location=self.location, @@ -350,12 +357,12 @@ class FileSystemCombinedPackageResource(PackageResourceHelper): repository_type = "filesystem" schema = package_pod_schema - def _uri(self): + def _uri(self) -> str: ver_str = self.get("version", "") return "%s<%s>" % (self.parent.filepath, ver_str) @cached_property - def parent(self): + def parent(self) -> FileSystemCombinedPackageFamilyResource: family = self._repository.get_resource( FileSystemCombinedPackageFamilyResource.key, location=self.location, @@ -364,17 +371,17 @@ def parent(self): return family @property - def base(self): + def base(self) -> str | None: return None # combined resource types do not have 'base' @cached_property - def state_handle(self): + def state_handle(self) -> float: return os.path.getmtime(self.parent.filepath) - def iter_variants(self): + def iter_variants(self) -> Iterator[FileSystemCombinedVariantResource]: num_variants = len(self._data.get("variants", [])) if num_variants == 0: - indexes = [None] + indexes: Iterable[int | None] = [None] else: indexes = range(num_variants) @@ -412,7 +419,7 @@ class FileSystemCombinedVariantResource(VariantResourceHelper): repository_type = "filesystem" @cached_property - def parent(self): + def parent(self) -> PackageRepositoryResource: package = self._repository.get_resource( FileSystemCombinedPackageResource.key, location=self.location, @@ -421,7 +428,7 @@ def parent(self): version=self.get("version")) return package - def _root(self): + def _root(self, ignore_shortlinks: bool = False) -> str | None: return None # combined resource types do not have 'root' @@ -557,37 +564,37 @@ def _uid(self): t.append(int(st.st_ino)) return tuple(t) - def get_package_family(self, name): + def get_package_family(self, name: str) -> PackageFamilyResource: return self.get_family(name) @pool_memcached_connections - def iter_package_families(self): + def iter_package_families(self) -> Iterator[PackageFamilyResource]: for family in self.get_families(): yield family @pool_memcached_connections - def iter_packages(self, package_family_resource): + def iter_packages(self, package_family_resource: PackageFamilyResource) -> Iterator[Package]: for package in self.get_packages(package_family_resource): yield package - def iter_variants(self, package_resource): + def iter_variants(self, package_resource: PackageResourceHelper) -> Iterator[VariantResource]: for variant in self.get_variants(package_resource): yield variant - def get_parent_package_family(self, package_resource): + def get_parent_package_family(self, package_resource: PackageResourceHelper) -> PackageRepositoryResource: return package_resource.parent - def get_parent_package(self, variant_resource): + def get_parent_package(self, variant_resource: VariantResource) -> PackageRepositoryResource: return variant_resource.parent - def get_variant_state_handle(self, variant_resource): + def get_variant_state_handle(self, variant_resource: VariantResource): package_resource = variant_resource.parent return package_resource.state_handle - def get_last_release_time(self, package_family_resource): + def get_last_release_time(self, package_family_resource: PackageFamilyResource): return package_family_resource.get_last_release_time() - def get_package_from_uri(self, uri): + def get_package_from_uri(self, uri: str): """ Example URIs: - /svr/packages/mypkg/1.0.0/package.py @@ -621,7 +628,7 @@ def get_package_from_uri(self, uri): pkg_ver = Version(pkg_ver_str) return self.get_package(pkg_name, pkg_ver) - def get_variant_from_uri(self, uri): + def get_variant_from_uri(self, uri: str) -> Variant | None: """ Example URIs: - /svr/packages/mypkg/1.0.0/package.py[1] @@ -657,7 +664,7 @@ def get_variant_from_uri(self, uri): return None - def ignore_package(self, pkg_name, pkg_version, allow_missing=False): + def ignore_package(self, pkg_name: str, pkg_version: Version, allow_missing=False) -> int: # find package, even if already ignored if not allow_missing: repo_copy = self._copy( @@ -688,7 +695,7 @@ def ignore_package(self, pkg_name, pkg_version, allow_missing=False): self._on_changed(pkg_name) return 1 - def unignore_package(self, pkg_name, pkg_version): + def unignore_package(self, pkg_name: str, pkg_version) -> int: # find and remove .ignore{ver} file if it exists ignore_file_was_removed = False filename = self.ignore_prefix + str(pkg_version) @@ -707,7 +714,7 @@ def unignore_package(self, pkg_name, pkg_version): else: return -1 - def remove_package(self, pkg_name, pkg_version): + def remove_package(self, pkg_name: str, pkg_version) -> bool: # ignore it first, so a partially deleted pkg is not visible i = self.ignore_package(pkg_name, pkg_version) if i == -1: @@ -735,7 +742,7 @@ def remove_package(self, pkg_name, pkg_version): return True - def remove_package_family(self, pkg_name, force=False): + def remove_package_family(self, pkg_name: str, force=False) -> bool: # get a non-cached copy and see if fam exists repo_copy = self._copy( disable_pkg_ignore=True, @@ -765,7 +772,7 @@ def remove_package_family(self, pkg_name, force=False): self._on_changed(pkg_name) return True - def remove_ignored_since(self, days, dry_run=False, verbose=False): + def remove_ignored_since(self, days, dry_run=False, verbose=False) -> int: now = int(time.time()) num_removed = 0 @@ -852,7 +859,7 @@ def file_lock_dir(self): return dirname - def pre_variant_install(self, variant_resource): + def pre_variant_install(self, variant_resource: VariantResourceHelper): if not variant_resource.version: return @@ -893,7 +900,7 @@ def on_variant_install_cancelled(self, variant_resource): family_path = os.path.join(self.location, variant_resource.name) self._delete_stale_build_tagfiles(family_path) - def install_variant(self, variant_resource, dry_run=False, overrides=None): + def install_variant(self, variant_resource: VariantResource, dry_run=False, overrides=None) -> VariantResource: overrides = overrides or {} # Name and version overrides are a special case - they change the @@ -940,7 +947,7 @@ def install_variant(self, variant_resource, dry_run=False, overrides=None): ) # install the variant - def _create_variant(): + def _create_variant() -> VariantResource: return self._create_variant( variant_resource, dry_run=dry_run, @@ -1003,7 +1010,7 @@ def _lock_package(self, package_name, package_version=None): except NotLocked: pass - def clear_caches(self): + def clear_caches(self) -> None: super(FileSystemPackageRepository, self).clear_caches() self.get_families.cache_clear() self.get_family.cache_clear() @@ -1018,7 +1025,7 @@ def clear_caches(self): # unfortunately we need to clear file cache across the board clear_file_caches() - def get_package_payload_path(self, package_name, package_version=None): + def get_package_payload_path(self, package_name: str, package_version=None) -> str: path = os.path.join(self.location, package_name) if package_version: @@ -1123,7 +1130,7 @@ def ignore_dir(name): def _is_valid_package_directory(self, path): return bool(self._get_file(path, "package")[0]) - def _get_families(self): + def _get_families(self) -> list[PackageFamilyResource]: families = [] for name, ext in self._get_family_dirs(): if ext is None: # is a directory @@ -1141,7 +1148,7 @@ def _get_families(self): return families - def _get_family(self, name): + def _get_family(self, name: str) -> PackageFamilyResource | None: is_valid_package_name(name, raise_error=True) if os.path.isdir(os.path.join(self.location, name)): # force case-sensitive match on pkg family dir, on case-insensitive platforms @@ -1171,13 +1178,13 @@ def _get_family(self, name): ) return None - def _get_packages(self, package_family_resource): + def _get_packages(self, package_family_resource: PackageFamilyResource) -> list[Package]: return [x for x in package_family_resource.iter_packages()] - def _get_variants(self, package_resource): + def _get_variants(self, package_resource: PackageResourceHelper) -> list[VariantResource]: return [x for x in package_resource.iter_variants()] - def _get_file(self, path, package_filename=None): + def _get_file(self, path, package_filename=None) -> tuple[str, FileFormat] | tuple[None, None]: if package_filename: package_filenames = [package_filename] else: @@ -1192,7 +1199,7 @@ def _get_file(self, path, package_filename=None): return filepath, format_ return None, None - def _create_family(self, name): + def _create_family(self, name: str): path = os.path.join(self.location, name) if not os.path.exists(path): os.makedirs(path) @@ -1200,7 +1207,7 @@ def _create_family(self, name): self._on_changed(name) return self.get_package_family(name) - def _create_variant(self, variant, dry_run=False, overrides=None): + def _create_variant(self, variant: VariantResource, dry_run=False, overrides=None) -> VariantResource | None: # special case overrides variant_name = overrides.get("name") or variant.name variant_version = overrides.get("version") or variant.version @@ -1220,7 +1227,7 @@ def _create_variant(self, variant, dry_run=False, overrides=None): % family.filepath) # find the package if it already exists - existing_package = None + existing_package: Package | None = None for package in self.iter_packages(family): if package.version == variant_version: @@ -1259,7 +1266,7 @@ def _create_variant(self, variant, dry_run=False, overrides=None): # converted to a Config object. We need it as the raw dict that you'd # see in a package.py. # - def _get_package_data(pkg): + def _get_package_data(pkg: PackageRepositoryResourceWrapper): data = pkg.validated_data() if hasattr(pkg, "_data"): raw_data = pkg._data @@ -1489,7 +1496,7 @@ def _remove_build_keys(obj): return new_variant - def _on_changed(self, pkg_name): + def _on_changed(self, pkg_name: str): """Called when a package is added/removed/changed. """ @@ -1504,7 +1511,7 @@ def _on_changed(self, pkg_name): # clear internal caches, otherwise change may not be visible self.clear_caches() - def _delete_stale_build_tagfiles(self, family_path): + def _delete_stale_build_tagfiles(self, family_path: str): now = time.time() for name in os.listdir(family_path): diff --git a/src/rezplugins/package_repository/memory.py b/src/rezplugins/package_repository/memory.py index 74d50731e..546712d5e 100644 --- a/src/rezplugins/package_repository/memory.py +++ b/src/rezplugins/package_repository/memory.py @@ -5,6 +5,8 @@ """ In-memory package repository """ +from __future__ import annotations + from rez.package_repository import PackageRepository from rez.package_resources import PackageFamilyResource, VariantResourceHelper, \ PackageResourceHelper, package_pod_schema @@ -12,6 +14,12 @@ from rez.utils.resources import ResourcePool, cached_property from rez.version import VersionedObject +from typing import Iterator, TYPE_CHECKING + +if TYPE_CHECKING: + from rez.packages import VariantResource + from rez.package_resources import PackageRepositoryResource + # This repository type is used when loading 'developer' packages (a package.yaml # or package.py in a developer's working directory), and when programmatically @@ -29,7 +37,7 @@ class MemoryPackageFamilyResource(PackageFamilyResource): def _uri(self): return "%s:%s" % (self.location, self.name) - def iter_packages(self): + def iter_packages(self) -> Iterator[MemoryPackageResource]: data = self._repository.data.get(self.name, {}) # check for unversioned package @@ -57,16 +65,16 @@ class MemoryPackageResource(PackageResourceHelper): repository_type = "memory" schema = package_pod_schema - def _uri(self): + def _uri(self) -> str: obj = VersionedObject.construct(self.name, self.version) return "%s:%s" % (self.location, str(obj)) @property - def base(self): + def base(self) -> str | None: return None # memory types do not have 'base' @cached_property - def parent(self): + def parent(self) -> PackageRepositoryResource: family = self._repository.get_resource( MemoryPackageFamilyResource.key, location=self.location, @@ -86,11 +94,11 @@ class MemoryVariantResource(VariantResourceHelper): key = "memory.variant" repository_type = "memory" - def _root(self): + def _root(self, ignore_shortlinks: bool = False) -> str | None: return None # memory types do not have 'root' @cached_property - def parent(self): + def parent(self) -> PackageRepositoryResource: package = self._repository.get_resource( MemoryPackageResource.key, location=self.location, @@ -135,7 +143,7 @@ def name(cls): return "memory" @classmethod - def create_repository(cls, repository_data): + def create_repository(cls, repository_data) -> MemoryPackageRepository: """Create a standalone, in-memory repository. Using this function bypasses the `package_repository_manager` singleton. @@ -155,7 +163,7 @@ def create_repository(cls, repository_data): repo.data = repository_data return repo - def __init__(self, location, resource_pool): + def __init__(self, location: str, resource_pool: ResourcePool): """Create an in-memory package repository. Args: @@ -167,7 +175,7 @@ def __init__(self, location, resource_pool): self.register_resource(MemoryPackageResource) self.register_resource(MemoryVariantResource) - def get_package_family(self, name): + def get_package_family(self, name: str) -> MemoryPackageFamilyResource | None: is_valid_package_name(name, raise_error=True) if name in self.data: family = self.get_resource( @@ -177,23 +185,23 @@ def get_package_family(self, name): return family return None - def iter_package_families(self): + def iter_package_families(self) -> Iterator[MemoryPackageFamilyResource | None]: for name in self.data.keys(): family = self.get_package_family(name) yield family - def iter_packages(self, package_family_resource): + def iter_packages(self, package_family_resource: MemoryPackageFamilyResource) -> Iterator[MemoryPackageResource]: for package in package_family_resource.iter_packages(): yield package - def iter_variants(self, package_resource): + def iter_variants(self, package_resource: PackageResourceHelper) -> Iterator[VariantResource]: for variant in package_resource.iter_variants(): yield variant - def get_parent_package_family(self, package_resource): + def get_parent_package_family(self, package_resource: PackageResourceHelper) -> PackageFamilyResource: return package_resource.parent - def get_parent_package(self, variant_resource): + def get_parent_package(self, variant_resource: VariantResource): return variant_resource.parent diff --git a/src/rezplugins/release_hook/amqp.py b/src/rezplugins/release_hook/amqp.py index cb3cefc3f..2d6c25e0f 100644 --- a/src/rezplugins/release_hook/amqp.py +++ b/src/rezplugins/release_hook/amqp.py @@ -5,10 +5,13 @@ """ Publishes a message to the broker. """ +from __future__ import annotations + from rez.release_hook import ReleaseHook from rez.utils.logging_ import print_error, print_debug from rez.utils.amqp import publish_message from rez.config import config +from typing import Any class AmqpReleaseHook(ReleaseHook): @@ -55,7 +58,7 @@ def post_release(self, user, install_path, variants, **kwargs): package = self.package # build the message dict - data = {} + data: dict[str, Any] = {} data["package"] = dict( name=package.name, version=str(package.version),