Skip to content

Commit

Permalink
Use decorator to define object interface method mappings, auto infer …
Browse files Browse the repository at this point in the history
…types
  • Loading branch information
loopj committed Jan 6, 2025
1 parent f8e0011 commit e96e9f3
Show file tree
Hide file tree
Showing 15 changed files with 375 additions and 238 deletions.
15 changes: 7 additions & 8 deletions src/aiovantage/object_interfaces/anemo_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,17 @@

from decimal import Decimal

from .base import Interface
from .base import Interface, method


class AnemoSensorInterface(Interface):
"""Interface for querying and controlling anemo (wind) sensors."""

method_signatures = {
"AnemoSensor.GetSpeed": Decimal,
"AnemoSensor.GetSpeedHW": Decimal,
}

# Status properties
speed: Decimal | None = None # "AnemoSensor.GetSpeed"
# Properties
speed: Decimal | None = None

# Methods
@method("AnemoSensor.GetSpeed", property="speed")
async def get_speed(self) -> Decimal:
"""Get the speed of an anemo sensor, using cached value if available.
Expand All @@ -27,6 +23,7 @@ async def get_speed(self) -> Decimal:
# -> R:INVOKE <id> <speed> AnemoSensor.GetSpeed
return await self.invoke("AnemoSensor.GetSpeed")

@method("AnemoSensor.GetSpeedHW")
async def get_speed_hw(self) -> Decimal:
"""Get the speed of an anemo sensor directly from the hardware.
Expand All @@ -37,6 +34,7 @@ async def get_speed_hw(self) -> Decimal:
# -> R:INVOKE <id> <speed> AnemoSensor.GetSpeedHW
return await self.invoke("AnemoSensor.GetSpeedHW")

@method("AnemoSensor.SetSpeed")
async def set_speed(self, speed: Decimal) -> None:
"""Set the speed of an anemo sensor.
Expand All @@ -47,6 +45,7 @@ async def set_speed(self, speed: Decimal) -> None:
# -> R:INVOKE <id> <rcode> AnemoSensor.SetSpeed <speed>
await self.invoke("AnemoSensor.SetSpeed", speed)

@method("AnemoSensor.SetSpeedSW")
async def set_speed_sw(self, speed: Decimal) -> None:
"""Set the cached speed of an anemo sensor.
Expand Down
78 changes: 70 additions & 8 deletions src/aiovantage/object_interfaces/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Base class for command client interfaces."""

from typing import Any, ClassVar, TypeVar, overload
from types import NoneType
from typing import Any, Protocol, TypeVar, get_type_hints, overload, runtime_checkable

from aiovantage.command_client import CommandClient
from aiovantage.command_client.utils import (
Expand All @@ -13,11 +14,72 @@
T = TypeVar("T")


class Interface:
"""Base class for command client object interfaces."""
@runtime_checkable
class MethodFunction(Protocol):
"""Type hint for a function tagged with a Vantage method."""

_method: str
_property: str | None

async def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""Call the getter function."""


def method(method: str, *, property: str | None = None):
"""Decorator to map a python function to a Vantage method.
This is used to automatically keep track of expected return types for
Vantage method calls so we can parse responses correctly, both when
directly invoking methods and when receiving status messages.
Optionally, a property name can be associated with the method, which can
be used to update the state of the object when receiving status messages,
or to fetch the initial state of the object.
Args:
method: The vantage method name to associate with the function.
property: Optional property name to associate with the function.
"""

def decorator(func):
func._method = method
func._property = property
return func

return decorator

method_signatures: ClassVar[dict[str, type | None]] = {}
"""A mapping of method names to their return types."""

class InterfaceMeta(type):
"""Metaclass for object interfaces."""

method_signatures: dict[str, type]

def __new__(
cls: type["InterfaceMeta"],
name: str,
bases: tuple[type, ...],
dct: dict[str, Any],
):
"""Create a new object interface class."""
cls_obj = super().__new__(cls, name, bases, dct)
cls_obj.method_signatures = {}

# Include getter methods from base classes
for base in bases:
if hasattr(base, "method_signatures"):
cls_obj.method_signatures.update(base.method_signatures)

for attr in dct.values():
if isinstance(attr, MethodFunction):
# Collect method signatures
if method_signature := get_type_hints(attr).get("return"):
cls_obj.method_signatures[attr._method] = method_signature

return cls_obj


class Interface(metaclass=InterfaceMeta):
"""Base class for command client object interfaces."""

command_client: CommandClient | None = None
"""The command client to use for sending requests."""
Expand Down Expand Up @@ -111,7 +173,7 @@ def parse_response(
signature = as_type or cls._get_signature(method)

# Return early if this method has no return value
if signature is None:
if signature is NoneType:
return None

# Parse the response
Expand All @@ -133,10 +195,10 @@ def parse_response(
return parsed_response

@classmethod
def _get_signature(cls, method: str) -> type | None:
def _get_signature(cls, method: str) -> type:
# Get the signature of a method.
for klass in cls.__mro__:
if issubclass(klass, Interface) and method in klass.method_signatures:
return klass.method_signatures[method]

return None
raise ValueError(f"No signature found for method '{method}'.")
Loading

0 comments on commit e96e9f3

Please sign in to comment.