From a52bd216a2eb69583d9b05860009875d120ba7c0 Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Sat, 9 Nov 2019 02:05:08 +0900 Subject: [PATCH] Support arrary column (#36) * Support arrary column --- README.md | 6 +- pydataapi/pydataapi.py | 91 +++++++++++++++++++++++------ scripts/test.sh | 2 +- tests/pydataapi/test_pydataapi.py | 96 ++++++++++++++++++++++++++++--- 4 files changed, 166 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 2c9cfa4..f64d7b8 100644 --- a/README.md +++ b/README.md @@ -11,9 +11,13 @@ Also, the package includes DB API 2.0 Client and SQLAlchemy Dialects. ## Features - A user-friendly client which supports SQLAlchemy models -- SQLAlchemy Dialects (experimental) +- SQLAlchemy Dialects - DB API 2.0 compatible client [PEP 249](https://www.python.org/dev/peps/pep-0249/) +## Support Database Engines +- MySQL +- PostgreSQL + ## What's AWS Aurora Serverless's Data API? https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/data-api.html diff --git a/pydataapi/pydataapi.py b/pydataapi/pydataapi.py index 808028e..036b03f 100644 --- a/pydataapi/pydataapi.py +++ b/pydataapi/pydataapi.py @@ -8,6 +8,7 @@ List, Optional, Sequence, + Tuple, Type, TypeVar, Union, @@ -32,6 +33,20 @@ 'compile_kwargs': {"literal_binds": True}, } +BOOLEAN_VALUE: str = 'booleanValue' +STRING_VALUE: str = 'stringValue' +LONG_VALUE: str = 'longValue' +DOUBLE_VALUE: str = 'doubleValue' +BLOB_VALUE: str = 'blobValue' +IS_NULL: str = 'isNull' +ARRAY_VALUE: str = 'arrayValue' +ARRAY_VALUES: str = 'arrayValues' +BOOLEAN_VALUES: str = 'booleanValues' +STRING_VALUES: str = 'stringValues' +LONG_VALUES: str = 'longValues' +DOUBLE_VALUES: str = 'doubleValues' +BLOB_VALUES: str = 'blobValues' + def generate_sql(query: Union[Query, Insert, Update, Delete, Select]) -> str: if hasattr(query, 'statement'): @@ -96,19 +111,51 @@ def create_process_result_value_function_list( ] +def convert_array_value(value: Union[List, Tuple]) -> Dict[str, Any]: + first_value: Any = value[0] + if isinstance(first_value, (list, tuple)): + return { + ARRAY_VALUE: { + ARRAY_VALUES: [ + convert_array_value(nested_value) for nested_value in value + ] + } + } + + values_key: Optional[str] = None + if isinstance(first_value, bool): + values_key = BOOLEAN_VALUES + elif isinstance(first_value, str): + values_key = STRING_VALUES + elif isinstance(first_value, int): + values_key = LONG_VALUES + elif isinstance(first_value, float): + values_key = DOUBLE_VALUES + elif isinstance(first_value, bytes): + values_key = BLOB_VALUES + if values_key: + return {ARRAY_VALUE: {values_key: list(value)}} + raise Exception(f'unsupported array type {type(value[0])}]: {value} ') + + def convert_value(value: Any) -> Dict[str, Any]: if isinstance(value, bool): - return {'booleanValue': value} + return {BOOLEAN_VALUE: value} elif isinstance(value, str): - return {'stringValue': value} + return {STRING_VALUE: value} elif isinstance(value, int): - return {'longValue': value} + return {LONG_VALUE: value} elif isinstance(value, float): - return {'doubleValue': value} + return {DOUBLE_VALUE: value} elif isinstance(value, bytes): - return {'blobValue': value} + return {BLOB_VALUE: value} elif value is None: - return {'isNull': True} + return {IS_NULL: True} + elif isinstance(value, (list, tuple)): + if not value: + return {IS_NULL: True} + return convert_array_value(value) + # TODO: support structValue else: raise Exception(f'unsupported type {type(value)}: {value} ') @@ -121,6 +168,23 @@ def create_sql_parameters( ] +def _get_value_from_row(row: Dict[str, Any]) -> Any: + key = tuple(row.keys())[0] + if key == IS_NULL: + return None + value = row[key] + if key == ARRAY_VALUE: + array_key: str = tuple(value.keys())[0] + array_value: Union[List[Dict[str, Dict]], Dict[str, List]] = value[array_key] + if array_key == ARRAY_VALUES: + return [ + tuple(nested_value[ARRAY_VALUE].values())[0] # type: ignore + for nested_value in array_value + ] + return array_value + return value + + T = TypeVar('T') @@ -137,7 +201,7 @@ def __init__(self, generated_fields: List[Dict[str, Any]]): def generated_fields(self) -> List: if self._generated_fields is None: self._generated_fields = [ - list(f.values())[0] for f in self._generated_fields_raw + _get_value_from_row(f) for f in self._generated_fields_raw ] return self._generated_fields @@ -229,11 +293,7 @@ def __init__( if process_result_value_function_list: self._rows: Sequence[List] = [ [ - process_result_value( - None - if tuple(column.keys())[0] == 'isNull' - else tuple(column.values())[0] - ) + process_result_value(_get_value_from_row(column)) for column, process_result_value in zip( row, process_result_value_function_list ) @@ -242,12 +302,7 @@ def __init__( ] else: self._rows = [ - [ - None - if tuple(column.keys())[0] == 'isNull' - else tuple(column.values())[0] - for column in row - ] + [_get_value_from_row(column) for column in row] for row in response.get('records', []) # type: ignore ] self._column_metadata: List[Dict[str, Any]] = response.get('columnMetadata', []) diff --git a/scripts/test.sh b/scripts/test.sh index 2c8e1e8..08ae020 100755 --- a/scripts/test.sh +++ b/scripts/test.sh @@ -2,5 +2,5 @@ set -e export AWS_DEFAULT_REGION=us-west-2 -pytest --cov=pydataapi --ignore-glob=tests/integration/** tests +pytest --cov=pydataapi --ignore-glob=tests/integration/** --cov-report term-missing tests pytest --docker-compose-no-build --use-running-containers --docker-compose=tests/integration/docker-compose.yml tests/integration/ \ No newline at end of file diff --git a/tests/pydataapi/test_pydataapi.py b/tests/pydataapi/test_pydataapi.py index b6d7b8e..c9505c8 100644 --- a/tests/pydataapi/test_pydataapi.py +++ b/tests/pydataapi/test_pydataapi.py @@ -1,3 +1,5 @@ +from typing import Any, Dict + import pytest import sqlalchemy.types as types from pydantic import BaseModel, ValidationError @@ -8,6 +10,8 @@ Record, Result, UpdateResults, + _get_value_from_row, + convert_array_value, convert_value, create_sql_parameters, generate_sql, @@ -37,15 +41,29 @@ def mocked_client(mocker): return mocker.patch('boto3.client') -def test_convert_value() -> None: - assert convert_value('str') == {'stringValue': 'str'} - assert convert_value(123), {'longValue': 123} - assert convert_value(1.23), {'doubleValue': 1.23} - assert convert_value(True), {'booleanValue': True} - assert convert_value(False), {'booleanValue': False} - assert convert_value(b'bytes'), {'blobValue': b'bytes'} - assert convert_value(None), {'isNull': True} +@pytest.mark.parametrize( + 'input_value, expected', + [ + ('str', {'stringValue': 'str'}), + (123, {'longValue': 123}), + (1.23, {'doubleValue': 1.23}), + (True, {'booleanValue': True}), + (False, {'booleanValue': False}), + (b'bytes', {'blobValue': b'bytes'}), + (None, {'isNull': True}), + ([], {'isNull': True}), + ('str', {'stringValue': 'str'}), + ([123, 456], {'arrayValue': {'longValues': [123, 456]}}), + ([1.23, 4.56], {'arrayValue': {'doubleValues': [1.23, 4.56]}}), + ([True, False], {'arrayValue': {'booleanValues': [True, False]}}), + ([b'bytes', b'blob'], {'arrayValue': {'blobValues': [b'bytes', b'blob']}}), + ], +) +def test_convert_value(input_value: Any, expected: Dict[str, Any]) -> None: + assert convert_value(input_value) == expected + +def test_convert_value_fail() -> None: class Dummy: pass @@ -53,6 +71,39 @@ class Dummy: convert_value(Dummy()) +@pytest.mark.parametrize( + 'input_value, expected', + [ + (['str', 'string'], {'arrayValue': {'stringValues': ['str', 'string']}}), + ([123, 456], {'arrayValue': {'longValues': [123, 456]}}), + ([1.23, 4.56], {'arrayValue': {'doubleValues': [1.23, 4.56]}}), + ([True, False], {'arrayValue': {'booleanValues': [True, False]}}), + ([b'bytes', b'blob'], {'arrayValue': {'blobValues': [b'bytes', b'blob']}}), + ( + [[123, 456], [789]], + { + 'arrayValue': { + 'arrayValues': [ + {'arrayValue': {'longValues': [123, 456]}}, + {'arrayValue': {'longValues': [789]}}, + ] + } + }, + ), + ], +) +def test_convert_array_value(input_value: Any, expected: Dict[str, Any]) -> None: + assert convert_array_value(input_value) == expected + + +def test_convert_arrary_value_fail() -> None: + class Dummy: + pass + + with pytest.raises(Exception): + convert_array_value([Dummy()]) + + def test_generate_sql() -> None: class Users(declarative_base()): __tablename__ = 'users' @@ -70,6 +121,28 @@ class Users(declarative_base()): ) +@pytest.mark.parametrize( + 'input_value, expected', + [ + ({'arrayValue': {'stringValues': ['str', 'string']}}, ['str', 'string']), + ({'longValue': 123}, 123), + ( + { + 'arrayValue': { + 'arrayValues': [ + {'arrayValue': {'longValues': [123, 456]}}, + {'arrayValue': {'longValues': [789]}}, + ] + } + }, + [[123, 456], [789]], + ), + ], +) +def test_get_value_from_row(input_value: Dict[str, Any], expected: Any) -> None: + assert _get_value_from_row(input_value) == expected + + def test_create_parameters() -> None: expected = [ {'name': 'int', 'value': {'longValue': 1}}, @@ -162,6 +235,7 @@ def test_result() -> None: 'records': [ [{'longValue': 1}, {'stringValue': 'dog'}], [{'longValue': 2}, {'stringValue': 'cat'}], + [{'longValue': 3}, {'isNull': True}], ], "columnMetadata": column_metadata, } @@ -169,17 +243,21 @@ def test_result() -> None: assert result[0] == [1, 'dog'] assert result[1] == [2, 'cat'] - dog, cat = result[0:2] + assert result[2] == [3, None] + dog, cat, none = result[0:3] assert dog == [1, 'dog'] assert cat == [2, 'cat'] + assert none == [3, None] assert next(result) == Record([1, 'dog'], ['id', 'name']) assert next(result) == Record([2, 'cat'], ['id', 'name']) + assert next(result) == Record([3, None], ['id', 'name']) with pytest.raises(StopIteration): next(result) assert result.all() == [ Record([1, 'dog'], ['id', 'name']), Record([2, 'cat'], ['id', 'name']), + Record([3, None], ['id', 'name']), ] assert result.first() == Record([1, 'dog'], ['id', 'name']) with pytest.raises(MultipleResultsFound):