From 4a1612bfc66fac36a4b1c43ab7defa23fb6e9d4b Mon Sep 17 00:00:00 2001 From: Avasam Date: Mon, 28 Oct 2024 14:24:39 -0400 Subject: [PATCH] Make reinitialize_command's return type Generic when "command" argument is a Command --- distutils/cmd.py | 17 ++++++++++++++++- distutils/dist.py | 20 +++++++++++++++++++- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/distutils/cmd.py b/distutils/cmd.py index 2bb97956..6ffe7bd4 100644 --- a/distutils/cmd.py +++ b/distutils/cmd.py @@ -4,15 +4,20 @@ in the distutils.command package. """ +from __future__ import annotations + import logging import os import re import sys +from typing import TypeVar, overload from . import _modified, archive_util, dir_util, file_util, util from ._log import log from .errors import DistutilsOptionError +_CommandT = TypeVar("_CommandT", bound="Command") + class Command: """Abstract base class for defining command classes, the "worker bees" @@ -305,7 +310,17 @@ def get_finalized_command(self, command, create=True): # XXX rename to 'get_reinitialized_command()'? (should do the # same in dist.py, if so) - def reinitialize_command(self, command, reinit_subcommands=False): + @overload + def reinitialize_command( + self, command: str, reinit_subcommands: bool = False + ) -> Command: ... + @overload + def reinitialize_command( + self, command: _CommandT, reinit_subcommands: bool = False + ) -> _CommandT: ... + def reinitialize_command( + self, command: str | Command, reinit_subcommands=False + ) -> Command: return self.distribution.reinitialize_command(command, reinit_subcommands) def run_command(self, command): diff --git a/distutils/dist.py b/distutils/dist.py index 154301ba..86527b15 100644 --- a/distutils/dist.py +++ b/distutils/dist.py @@ -4,6 +4,8 @@ being built/installed/distributed. """ +from __future__ import annotations + import contextlib import logging import os @@ -13,6 +15,7 @@ import warnings from collections.abc import Iterable from email import message_from_file +from typing import TYPE_CHECKING, TypeVar, overload from packaging.utils import canonicalize_name, canonicalize_version @@ -27,6 +30,11 @@ from .fancy_getopt import FancyGetopt, translate_longopt from .util import check_environ, rfc822_escape, strtobool +if TYPE_CHECKING: + from .cmd import Command + +_CommandT = TypeVar("_CommandT", bound="Command") + # Regex to define acceptable Distutils command names. This is not *quite* # the same as a Python NAME -- I don't allow leading underscores. The fact # that they're very similar is no coincidence; the default naming scheme is @@ -900,7 +908,17 @@ def _set_command_options(self, command_obj, option_dict=None): # noqa: C901 except ValueError as msg: raise DistutilsOptionError(msg) - def reinitialize_command(self, command, reinit_subcommands=False): + @overload + def reinitialize_command( + self, command: str, reinit_subcommands: bool = False + ) -> Command: ... + @overload + def reinitialize_command( + self, command: _CommandT, reinit_subcommands: bool = False + ) -> _CommandT: ... + def reinitialize_command( + self, command: str | Command, reinit_subcommands=False + ) -> Command: """Reinitializes a command to the state it was in when first returned by 'get_command_obj()': ie., initialized but not yet finalized. This provides the opportunity to sneak option