From 9d0f6aa581611064dc563a64dd9c65451f5c737e Mon Sep 17 00:00:00 2001 From: James Smith Date: Mon, 23 Dec 2024 19:47:36 -0800 Subject: [PATCH] Allowing providing a custom default ssl context function --- src/aiovantage/connection.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/aiovantage/connection.py b/src/aiovantage/connection.py index 8c84395..c71d3c2 100644 --- a/src/aiovantage/connection.py +++ b/src/aiovantage/connection.py @@ -1,14 +1,28 @@ """Wrapper for an asyncio connection to a Vantage controller.""" import asyncio +from collections.abc import Callable from ssl import CERT_NONE, SSLContext, create_default_context +from typing import ClassVar from .errors import ClientConnectionError, ClientTimeoutError +def _get_default_context() -> SSLContext: + """Create a default SSL context.""" + # We don't have a local issuer certificate to check against, and we'll be + # connecting to an IP address so we can't check the hostname + ctx = create_default_context() + ctx.check_hostname = False + ctx.verify_mode = CERT_NONE + return ctx + + class BaseConnection: """Wrapper for an asyncio connection to a Vantage controller.""" + ssl_context_factory: ClassVar[Callable[[], SSLContext]] = _get_default_context + default_port: int default_ssl_port: int buffer_limit: int = 2**16 @@ -30,11 +44,7 @@ def __init__( # Set up the SSL context self._ssl: SSLContext | None if ssl is True: - # We don't have a local issuer certificate to check against, and we'll be - # connecting to an IP address so we can't check the hostname - self._ssl = create_default_context() - self._ssl.check_hostname = False - self._ssl.verify_mode = CERT_NONE + self._ssl = self.ssl_context_factory() elif isinstance(ssl, SSLContext): self._ssl = ssl else: