From 4bf009e44c01183cd10403a1bc2cdcd2528882cc Mon Sep 17 00:00:00 2001 From: Paul Cruse III Date: Wed, 25 Oct 2023 22:06:34 -0500 Subject: [PATCH] making sure cors makes sense --- acai_aws/apigateway/requirements.py | 2 +- acai_aws/apigateway/response.py | 22 +++++++++++++++------- acai_aws/apigateway/router.py | 3 ++- tests/acai_aws/apigateway/test_response.py | 4 ++++ 4 files changed, 22 insertions(+), 9 deletions(-) diff --git a/acai_aws/apigateway/requirements.py b/acai_aws/apigateway/requirements.py index 85bb57f..16c4843 100644 --- a/acai_aws/apigateway/requirements.py +++ b/acai_aws/apigateway/requirements.py @@ -8,7 +8,7 @@ def requirements(**kwargs): def decorator_func(func): def raise_timeout(*_): - raise ApiTimeOutException + raise ApiTimeOutException() def start_timeout(timeout=None): if kwargs.get('timeout') is not None or timeout is not None: diff --git a/acai_aws/apigateway/response.py b/acai_aws/apigateway/response.py index 1777cd7..b99acfd 100644 --- a/acai_aws/apigateway/response.py +++ b/acai_aws/apigateway/response.py @@ -7,10 +7,10 @@ class Response: - def __init__(self): + def __init__(self, **kwargs): self.__code = 200 self.__is_json = True - self.__open_cors = True + self.__cors = kwargs.get('cors', True) self.__base64_encoded = False self.__compress = False self.__content_type = '' @@ -19,8 +19,8 @@ def __init__(self): @property def headers(self): - if self.open_cors: - self.__set_open_cors() + if self.cors: + self.__set_cors() return self.__headers @headers.setter @@ -38,13 +38,21 @@ def content_type(self): def content_type(self, content_type ): self.__content_type = content_type + @property + def cors(self): + return self.__cors + + @cors.setter + def cors(self, access): + self.__cors = access + @property def open_cors(self): - return self.__open_cors + return self.__cors @open_cors.setter def open_cors(self, access): - self.__open_cors = access + self.__cors = access @property def base64_encoded(self): @@ -123,7 +131,7 @@ def __compress_body(self, body): file.write(body.encode('utf-8')) return base64.b64encode(bytes_io.getvalue()).decode('ascii') - def __set_open_cors(self): + def __set_cors(self): self.__headers['Access-Control-Allow-Origin'] = '*' self.__headers['Access-Control-Allow-Headers'] = '*' diff --git a/acai_aws/apigateway/router.py b/acai_aws/apigateway/router.py index dc8cd10..1f75611 100644 --- a/acai_aws/apigateway/router.py +++ b/acai_aws/apigateway/router.py @@ -18,6 +18,7 @@ def __init__(self, **kwargs): self.__with_auth = kwargs.get('with_auth') self.__on_error = kwargs.get('on_error') self.__on_timeout = kwargs.get('on_timeout') + self.__cors = kwargs.get('cors', True) self.__timeout = kwargs.get('timeout', None) self.__output_error = kwargs.get('output_error', False) self.__verbose = kwargs.get('verbose_logging', False) @@ -32,7 +33,7 @@ def auto_load(self): def route(self, event, context): request = Request(event, context, self.__timeout) - response = Response() + response = Response(cors=self.__cors) try: self.__log_verbose(title='request-received', log={'request': request}) self.__run_route_procedure(request, response) diff --git a/tests/acai_aws/apigateway/test_response.py b/tests/acai_aws/apigateway/test_response.py index 210e131..e97ca2a 100644 --- a/tests/acai_aws/apigateway/test_response.py +++ b/tests/acai_aws/apigateway/test_response.py @@ -35,6 +35,10 @@ def test_default_headers(self): ) def test_closed_cors_headers(self): + self.response.cors = False + self.assertDictEqual(self.response.headers, {}) + + def test_closed_open_cors_headers(self): self.response.open_cors = False self.assertDictEqual(self.response.headers, {})