From 1e10250e1dee4b6d07ac3ead1ed37fb0b331a510 Mon Sep 17 00:00:00 2001 From: Mikhail Koviazin Date: Tue, 26 Nov 2024 11:00:02 +0100 Subject: [PATCH] tests: _assert_connect to support min/max SSL version `ssl.wrap_socket` supported only `ssl_version` and hence this was what `_assert_connect` used. `SSLContext` OTOH supports settings explicitly minimum and maximum supported SSL versions. Use that to properly fix SSL tests. Signed-off-by: Mikhail Koviazin --- tests/test_asyncio/test_connect.py | 10 +++++----- tests/test_connect.py | 13 +++++++------ 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/tests/test_asyncio/test_connect.py b/tests/test_asyncio/test_connect.py index ac0465dd..dc92b2f1 100644 --- a/tests/test_asyncio/test_connect.py +++ b/tests/test_asyncio/test_connect.py @@ -125,7 +125,7 @@ async def test_tcp_ssl_version_mismatch(tcp_address): tcp_address, certfile=certfile, keyfile=keyfile, - ssl_version=ssl.TLSVersion.TLSv1_2, + maximum_ssl_version=ssl.TLSVersion.TLSv1_2, ) await conn.disconnect() @@ -135,7 +135,8 @@ async def _assert_connect( server_address, certfile=None, keyfile=None, - ssl_version=None, + minimum_ssl_version=ssl.TLSVersion.TLSv1_2, + maximum_ssl_version=ssl.TLSVersion.TLSv1_3, ): stop_event = asyncio.Event() finished = asyncio.Event() @@ -153,9 +154,8 @@ async def _handler(reader, writer): elif certfile: host, port = server_address context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - if ssl_version is not None: - context.minimum_version = ssl_version - context.maximum_version = ssl_version + context.minimum_version = minimum_ssl_version + context.maximum_version = maximum_ssl_version context.load_cert_chain(certfile=certfile, keyfile=keyfile) server = await asyncio.start_server(_handler, host=host, port=port, ssl=context) else: diff --git a/tests/test_connect.py b/tests/test_connect.py index b5464412..cc580008 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -100,7 +100,6 @@ def test_tcp_ssl_tls12_custom_ciphers(tcp_address, ssl_ciphers): tcp_address, certfile=certfile, keyfile=keyfile, - ssl_version=ssl.TLSVersion.TLSv1_2, ) @@ -141,7 +140,7 @@ def test_tcp_ssl_version_mismatch(tcp_address): tcp_address, certfile=certfile, keyfile=keyfile, - ssl_version=ssl.TLSVersion.TLSv1_2, + maximum_ssl_version=ssl.TLSVersion.TLSv1_2, ) @@ -170,14 +169,16 @@ def __init__( *args, certfile=None, keyfile=None, - ssl_version=ssl.TLSVersion.TLSv1, + minimum_ssl_version=ssl.TLSVersion.TLSv1_2, + maximum_ssl_version=ssl.TLSVersion.TLSv1_3, **kw, ) -> None: self._ready_event = threading.Event() self._stop_requested = False self._certfile = certfile self._keyfile = keyfile - self._ssl_version = ssl_version + self._minimum_ssl_version = minimum_ssl_version + self._maximum_ssl_version = maximum_ssl_version super().__init__(*args, **kw) def service_actions(self): @@ -199,8 +200,8 @@ def get_request(self): newsocket, fromaddr = self.socket.accept() sslctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) sslctx.load_cert_chain(self._certfile, self._keyfile) - sslctx.minimum_version = self._ssl_version - sslctx.maximum_version = self._ssl_version + sslctx.minimum_version = self._minimum_ssl_version + sslctx.maximum_version = self._maximum_ssl_version connstream = sslctx.wrap_socket( newsocket, server_side=True,