Skip to content

Commit

Permalink
Strict typing, move invoke/parsing into command_client
Browse files Browse the repository at this point in the history
  • Loading branch information
loopj committed Jan 7, 2025
1 parent 2946bf9 commit b04ea60
Show file tree
Hide file tree
Showing 12 changed files with 178 additions and 156 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ max-complexity = 15
convention = "google"

[tool.pyright]
typeCheckingMode = "standard"
typeCheckingMode = "strict"

[tool.bumpver]
current_version = "0.16.0"
Expand Down
10 changes: 10 additions & 0 deletions src/aiovantage/command_client/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ async with CommandClient("10.2.0.103") as client:
await client.command("LOAD", 118, 100)
```

### Invoke a method on an object

```python
from aiovantage.command_client import CommandClient

async with CommandClient("10.2.0.103") as client:
# Get the current level of load with id 118, as a Decimal
await client.invoke(118, "Load.GetLevel", as_type=Decimal)
```

### Subscribe to load events

```python
Expand Down
49 changes: 45 additions & 4 deletions src/aiovantage/command_client/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import asyncio
import logging
from dataclasses import dataclass
from decimal import Decimal
from ssl import SSLContext
from types import TracebackType
from typing import TypeVar

from typing_extensions import Self

Expand All @@ -19,7 +19,14 @@
ObjectOfflineError,
)

from .utils import encode_params, tokenize_response
from .utils import (
ParameterType,
encode_params,
parse_object_response,
tokenize_response,
)

T = TypeVar("T")


class CommandConnection(BaseConnection):
Expand Down Expand Up @@ -85,7 +92,7 @@ def close(self) -> None:
async def command(
self,
command: str,
*params: str | float | Decimal,
*params: ParameterType,
force_quotes: bool = False,
connection: CommandConnection | None = None,
) -> CommandResponse:
Expand All @@ -105,11 +112,45 @@ async def command(
if params:
request += f" {encode_params(*params, force_quotes=force_quotes)}"

# Send the request and parse the response
# Send the request
*data, return_line = await self.raw_request(request, connection=connection)

# Break the response into tokens
command, *args = tokenize_response(return_line)

# Parse the response
return CommandResponse(command[2:], args, data)

async def invoke(
self,
vid: int,
method: str,
*params: ParameterType,
as_type: type[T] | None = None,
) -> T | None:
"""Invoke a method on an object, and return the parsed response.
Args:
vid: The vid of the object to invoke the method on.
method: The method to invoke.
params: The parameters to send with the method.
as_type: The expected return type of the method.
Returns:
A parsed response, or None if no response was expected.
"""
# INVOKE <id> <Interface.Method> <arg1> <arg2> ...
# -> R:INVOKE <id> <result> <Interface.Method> <arg1> <arg2> ...

# Send the command
response = await self.command("INVOKE", vid, method, *params)

# Break the response into tokens
_id, result, _method, *args = response.args

# Parse the response
return parse_object_response(result, *args, as_type=as_type)

async def raw_request(
self, request: str, connection: CommandConnection | None = None
) -> list[str]:
Expand Down
82 changes: 64 additions & 18 deletions src/aiovantage/command_client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
import struct
from decimal import Decimal
from enum import IntEnum
from typing import Any, TypeVar, cast
from typing import Any, TypeVar, cast, get_type_hints

TOKEN_PATTERN = re.compile(r'"([^""]*(?:""[^""]*)*)"|(\{.*?\})|(\[.*?\])|(\S+)')

ParameterType = str | bool | int | float | bytearray | Decimal | IntEnum | dt.datetime

T = TypeVar("T")
ParameterType = str | bool | int | float | Decimal | bytearray


def tokenize_response(string: str) -> list[str]:
Expand Down Expand Up @@ -50,7 +51,6 @@ def parse_param(arg: str, klass: type[T]) -> T:
Raises:
ValueError: If the parameter is of an unsupported type.
"""
parsed: Any
if klass is int:
parsed = int(arg)
elif klass is bool:
Expand All @@ -63,6 +63,8 @@ def parse_param(arg: str, klass: type[T]) -> T:
parsed = dt.datetime.fromtimestamp(int(arg))
elif klass is Decimal:
parsed = parse_fixed_param(arg)
elif klass is float:
parsed = float(arg)
elif issubclass(klass, IntEnum):
# Support both integer and string values for IntEnum
parsed = klass(int(arg)) if arg.isdigit() else klass[arg]
Expand All @@ -72,7 +74,7 @@ def parse_param(arg: str, klass: type[T]) -> T:
return cast(T, parsed)


def encode_params(*params: Any, force_quotes: bool = False) -> str:
def encode_params(*params: ParameterType, force_quotes: bool = False) -> str:
"""Encode a list of parameters for sending to the Host Command service.
Converts all params to strings, wraps strings in double quotes, and escapes
Expand All @@ -89,20 +91,24 @@ def encode_params(*params: Any, force_quotes: bool = False) -> str:
TypeError: If a parameter is of an unsupported type.
"""

def encode(param: Any) -> str:
if isinstance(param, str):
return encode_string_param(param, force_quotes=force_quotes)
if isinstance(param, bool):
return "1" if param else "0"
if isinstance(param, IntEnum):
return str(param.value)
if isinstance(param, int):
return str(param)
if isinstance(param, float | Decimal):
return f"{param:.3f}"
if isinstance(param, bytearray):
return encode_byte_param(param)
raise TypeError(f"Unsupported type: {type(param)}")
def encode(param: ParameterType) -> str:
match param:
case str():
return encode_string_param(param, force_quotes=force_quotes)
case bool():
return "1" if param else "0"
case IntEnum():
return str(param.value)
case int():
return str(param)
case float():
return f"{param:.3f}"
case Decimal():
return f"{param:.3f}"
case bytearray():
return encode_byte_param(param)
case dt.datetime():
return str(int(param.timestamp()))

return " ".join(encode(param) for param in params)

Expand Down Expand Up @@ -190,3 +196,43 @@ def encode_byte_param(byte_array: bytearray) -> str:

# Join the tokens with commas and wrap in curly braces
return "{" + ",".join(tokens) + "}"


def parse_object_response(
result: str, *args: str, as_type: type[T] | None = None
) -> T | None:
"""Parse an object interface "INVOKE" response or status message.
Args:
result: The result of the command.
args: The arguments that were sent with the command.
as_type: The expected return type of the method.
Returns:
A response parsed into the expected type.
"""
# -> R:INVOKE <id> <result> <Interface.Method> <arg1> <arg2> ...
# -> EL: <id> <Interface.Method> <result> <arg1> <arg2> ...
# -> S:STATUS <id> <Interface.Method> <result> <arg1> <arg2> ...

# If no type is specified, return the raw result
if as_type is None:
return None

# Otherwise, parse the result into the expected type
if type_hints := get_type_hints(as_type):
# Some methods return multiple values, in the result and in the arguments
# To support this, if the signature is an object with type hints, we'll assume
# we are packing the values into the attributes of the object
# The "result" is packed into the first argument, followed by the rest of the arguments
parsed_values: list[Any] = []
for arg, klass in zip([result, *args], type_hints.values(), strict=True):
parsed_values.append(parse_param(arg, klass))

parsed_response = as_type(*parsed_values)
else:
# Otherwise, parse a single return value
parsed_response = parse_param(result, as_type)

# Return the parsed result
return parsed_response
4 changes: 3 additions & 1 deletion src/aiovantage/config_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,9 @@ async def request(

# Render the method object to XML with xsdata and send the request
response = await self.raw_request(
method.interface, self._serializer.render(method), connection
method.interface,
self._serializer.render(method), # type: ignore
connection,
)

# Parse the XML doc
Expand Down
4 changes: 2 additions & 2 deletions src/aiovantage/controllers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(self, vantage: "Vantage") -> None:
self._initialized = False
self._lock = asyncio.Lock()

QuerySet.__init__(self, self._items, self._lazy_initialize)
super().__init__(self._items, self._lazy_initialize)

self.__post_init__()

Expand Down Expand Up @@ -148,7 +148,7 @@ async def initialize(self, *, fetch_state: bool = True) -> None:
obj.id,
{
field.name: getattr(obj, field.name)
for field in fields(type(obj))
for field in fields(type(obj)) # type: ignore
if field.name != "m_time"
and field.metadata.get("type") != "Ignore"
},
Expand Down
2 changes: 1 addition & 1 deletion src/aiovantage/controllers/masters.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def handle_interface_status(
return

state = {
"m_time": ObjectInterface.parse_response(method, result, *args),
"m_time": ObjectInterface.parse_status(method, result, *args),
}

self.update_state(vid, state)
6 changes: 3 additions & 3 deletions src/aiovantage/controllers/rgb_loads.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def handle_interface_status(
state: dict[str, Any] = {}

if method == "Load.GetLevel":
state["level"] = LoadInterface.parse_response(method, result, *args)
state["level"] = LoadInterface.parse_status(method, result, *args)

elif method == "RGBLoad.GetHSL":
if color := self._parse_color_channel_response(vid, method, result, *args):
Expand All @@ -71,7 +71,7 @@ def handle_interface_status(
state["rgbw"] = color

elif method == "ColorTemperature.Get":
state["color_temp"] = ColorTemperatureInterface.parse_response(
state["color_temp"] = ColorTemperatureInterface.parse_status(
method, result, *args
)

Expand Down Expand Up @@ -106,7 +106,7 @@ def _parse_color_channel_response(
raise ValueError(f"Unsupported color channel method {method}")

# Parse the response
response = RGBLoadInterface.parse_response(method, result, *args)
response = RGBLoadInterface.parse_status(method, result, *args)

# Ignore updates for channels we don't care about
if response.channel < 0 or response.channel >= num_channels:
Expand Down
2 changes: 1 addition & 1 deletion src/aiovantage/controllers/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def handle_interface_status(
return

state = {
"running": TaskInterface.parse_response(method, result, *args),
"running": TaskInterface.parse_status(method, result, *args),
}

self.update_state(vid, state)
6 changes: 2 additions & 4 deletions src/aiovantage/controllers/thermostats.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,9 @@ def handle_interface_status(
"""Handle object interface status messages from the event stream."""
state: dict[str, Any] = {}
if method == "Thermostat.GetHoldMode":
state["hold_mode"] = ThermostatInterface.parse_response(
method, result, *args
)
state["hold_mode"] = ThermostatInterface.parse_status(method, result, *args)
elif method == "Thermostat.GetStatus":
state["status"] = ThermostatInterface.parse_response(method, result, *args)
state["status"] = ThermostatInterface.parse_status(method, result, *args)

self.update_state(vid, state)

Expand Down
Loading

0 comments on commit b04ea60

Please sign in to comment.