Skip to content

Commit

Permalink
Support arrary column (#36)
Browse files Browse the repository at this point in the history
* Support arrary column
  • Loading branch information
koxudaxi authored Nov 8, 2019
1 parent 0299e24 commit a52bd21
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 29 deletions.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,13 @@ Also, the package includes DB API 2.0 Client and SQLAlchemy Dialects.

## Features
- A user-friendly client which supports SQLAlchemy models
- SQLAlchemy Dialects (experimental)
- SQLAlchemy Dialects
- DB API 2.0 compatible client [PEP 249](https://www.python.org/dev/peps/pep-0249/)

## Support Database Engines
- MySQL
- PostgreSQL

## What's AWS Aurora Serverless's Data API?
https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/data-api.html

Expand Down
91 changes: 73 additions & 18 deletions pydataapi/pydataapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
Expand All @@ -32,6 +33,20 @@
'compile_kwargs': {"literal_binds": True},
}

BOOLEAN_VALUE: str = 'booleanValue'
STRING_VALUE: str = 'stringValue'
LONG_VALUE: str = 'longValue'
DOUBLE_VALUE: str = 'doubleValue'
BLOB_VALUE: str = 'blobValue'
IS_NULL: str = 'isNull'
ARRAY_VALUE: str = 'arrayValue'
ARRAY_VALUES: str = 'arrayValues'
BOOLEAN_VALUES: str = 'booleanValues'
STRING_VALUES: str = 'stringValues'
LONG_VALUES: str = 'longValues'
DOUBLE_VALUES: str = 'doubleValues'
BLOB_VALUES: str = 'blobValues'


def generate_sql(query: Union[Query, Insert, Update, Delete, Select]) -> str:
if hasattr(query, 'statement'):
Expand Down Expand Up @@ -96,19 +111,51 @@ def create_process_result_value_function_list(
]


def convert_array_value(value: Union[List, Tuple]) -> Dict[str, Any]:
first_value: Any = value[0]
if isinstance(first_value, (list, tuple)):
return {
ARRAY_VALUE: {
ARRAY_VALUES: [
convert_array_value(nested_value) for nested_value in value
]
}
}

values_key: Optional[str] = None
if isinstance(first_value, bool):
values_key = BOOLEAN_VALUES
elif isinstance(first_value, str):
values_key = STRING_VALUES
elif isinstance(first_value, int):
values_key = LONG_VALUES
elif isinstance(first_value, float):
values_key = DOUBLE_VALUES
elif isinstance(first_value, bytes):
values_key = BLOB_VALUES
if values_key:
return {ARRAY_VALUE: {values_key: list(value)}}
raise Exception(f'unsupported array type {type(value[0])}]: {value} ')


def convert_value(value: Any) -> Dict[str, Any]:
if isinstance(value, bool):
return {'booleanValue': value}
return {BOOLEAN_VALUE: value}
elif isinstance(value, str):
return {'stringValue': value}
return {STRING_VALUE: value}
elif isinstance(value, int):
return {'longValue': value}
return {LONG_VALUE: value}
elif isinstance(value, float):
return {'doubleValue': value}
return {DOUBLE_VALUE: value}
elif isinstance(value, bytes):
return {'blobValue': value}
return {BLOB_VALUE: value}
elif value is None:
return {'isNull': True}
return {IS_NULL: True}
elif isinstance(value, (list, tuple)):
if not value:
return {IS_NULL: True}
return convert_array_value(value)
# TODO: support structValue
else:
raise Exception(f'unsupported type {type(value)}: {value} ')

Expand All @@ -121,6 +168,23 @@ def create_sql_parameters(
]


def _get_value_from_row(row: Dict[str, Any]) -> Any:
key = tuple(row.keys())[0]
if key == IS_NULL:
return None
value = row[key]
if key == ARRAY_VALUE:
array_key: str = tuple(value.keys())[0]
array_value: Union[List[Dict[str, Dict]], Dict[str, List]] = value[array_key]
if array_key == ARRAY_VALUES:
return [
tuple(nested_value[ARRAY_VALUE].values())[0] # type: ignore
for nested_value in array_value
]
return array_value
return value


T = TypeVar('T')


Expand All @@ -137,7 +201,7 @@ def __init__(self, generated_fields: List[Dict[str, Any]]):
def generated_fields(self) -> List:
if self._generated_fields is None:
self._generated_fields = [
list(f.values())[0] for f in self._generated_fields_raw
_get_value_from_row(f) for f in self._generated_fields_raw
]
return self._generated_fields

Expand Down Expand Up @@ -229,11 +293,7 @@ def __init__(
if process_result_value_function_list:
self._rows: Sequence[List] = [
[
process_result_value(
None
if tuple(column.keys())[0] == 'isNull'
else tuple(column.values())[0]
)
process_result_value(_get_value_from_row(column))
for column, process_result_value in zip(
row, process_result_value_function_list
)
Expand All @@ -242,12 +302,7 @@ def __init__(
]
else:
self._rows = [
[
None
if tuple(column.keys())[0] == 'isNull'
else tuple(column.values())[0]
for column in row
]
[_get_value_from_row(column) for column in row]
for row in response.get('records', []) # type: ignore
]
self._column_metadata: List[Dict[str, Any]] = response.get('columnMetadata', [])
Expand Down
2 changes: 1 addition & 1 deletion scripts/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
set -e

export AWS_DEFAULT_REGION=us-west-2
pytest --cov=pydataapi --ignore-glob=tests/integration/** tests
pytest --cov=pydataapi --ignore-glob=tests/integration/** --cov-report term-missing tests
pytest --docker-compose-no-build --use-running-containers --docker-compose=tests/integration/docker-compose.yml tests/integration/
96 changes: 87 additions & 9 deletions tests/pydataapi/test_pydataapi.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any, Dict

import pytest
import sqlalchemy.types as types
from pydantic import BaseModel, ValidationError
Expand All @@ -8,6 +10,8 @@
Record,
Result,
UpdateResults,
_get_value_from_row,
convert_array_value,
convert_value,
create_sql_parameters,
generate_sql,
Expand Down Expand Up @@ -37,22 +41,69 @@ def mocked_client(mocker):
return mocker.patch('boto3.client')


def test_convert_value() -> None:
assert convert_value('str') == {'stringValue': 'str'}
assert convert_value(123), {'longValue': 123}
assert convert_value(1.23), {'doubleValue': 1.23}
assert convert_value(True), {'booleanValue': True}
assert convert_value(False), {'booleanValue': False}
assert convert_value(b'bytes'), {'blobValue': b'bytes'}
assert convert_value(None), {'isNull': True}
@pytest.mark.parametrize(
'input_value, expected',
[
('str', {'stringValue': 'str'}),
(123, {'longValue': 123}),
(1.23, {'doubleValue': 1.23}),
(True, {'booleanValue': True}),
(False, {'booleanValue': False}),
(b'bytes', {'blobValue': b'bytes'}),
(None, {'isNull': True}),
([], {'isNull': True}),
('str', {'stringValue': 'str'}),
([123, 456], {'arrayValue': {'longValues': [123, 456]}}),
([1.23, 4.56], {'arrayValue': {'doubleValues': [1.23, 4.56]}}),
([True, False], {'arrayValue': {'booleanValues': [True, False]}}),
([b'bytes', b'blob'], {'arrayValue': {'blobValues': [b'bytes', b'blob']}}),
],
)
def test_convert_value(input_value: Any, expected: Dict[str, Any]) -> None:
assert convert_value(input_value) == expected


def test_convert_value_fail() -> None:
class Dummy:
pass

with pytest.raises(Exception):
convert_value(Dummy())


@pytest.mark.parametrize(
'input_value, expected',
[
(['str', 'string'], {'arrayValue': {'stringValues': ['str', 'string']}}),
([123, 456], {'arrayValue': {'longValues': [123, 456]}}),
([1.23, 4.56], {'arrayValue': {'doubleValues': [1.23, 4.56]}}),
([True, False], {'arrayValue': {'booleanValues': [True, False]}}),
([b'bytes', b'blob'], {'arrayValue': {'blobValues': [b'bytes', b'blob']}}),
(
[[123, 456], [789]],
{
'arrayValue': {
'arrayValues': [
{'arrayValue': {'longValues': [123, 456]}},
{'arrayValue': {'longValues': [789]}},
]
}
},
),
],
)
def test_convert_array_value(input_value: Any, expected: Dict[str, Any]) -> None:
assert convert_array_value(input_value) == expected


def test_convert_arrary_value_fail() -> None:
class Dummy:
pass

with pytest.raises(Exception):
convert_array_value([Dummy()])


def test_generate_sql() -> None:
class Users(declarative_base()):
__tablename__ = 'users'
Expand All @@ -70,6 +121,28 @@ class Users(declarative_base()):
)


@pytest.mark.parametrize(
'input_value, expected',
[
({'arrayValue': {'stringValues': ['str', 'string']}}, ['str', 'string']),
({'longValue': 123}, 123),
(
{
'arrayValue': {
'arrayValues': [
{'arrayValue': {'longValues': [123, 456]}},
{'arrayValue': {'longValues': [789]}},
]
}
},
[[123, 456], [789]],
),
],
)
def test_get_value_from_row(input_value: Dict[str, Any], expected: Any) -> None:
assert _get_value_from_row(input_value) == expected


def test_create_parameters() -> None:
expected = [
{'name': 'int', 'value': {'longValue': 1}},
Expand Down Expand Up @@ -162,24 +235,29 @@ def test_result() -> None:
'records': [
[{'longValue': 1}, {'stringValue': 'dog'}],
[{'longValue': 2}, {'stringValue': 'cat'}],
[{'longValue': 3}, {'isNull': True}],
],
"columnMetadata": column_metadata,
}
)

assert result[0] == [1, 'dog']
assert result[1] == [2, 'cat']
dog, cat = result[0:2]
assert result[2] == [3, None]
dog, cat, none = result[0:3]
assert dog == [1, 'dog']
assert cat == [2, 'cat']
assert none == [3, None]
assert next(result) == Record([1, 'dog'], ['id', 'name'])
assert next(result) == Record([2, 'cat'], ['id', 'name'])
assert next(result) == Record([3, None], ['id', 'name'])
with pytest.raises(StopIteration):
next(result)

assert result.all() == [
Record([1, 'dog'], ['id', 'name']),
Record([2, 'cat'], ['id', 'name']),
Record([3, None], ['id', 'name']),
]
assert result.first() == Record([1, 'dog'], ['id', 'name'])
with pytest.raises(MultipleResultsFound):
Expand Down

0 comments on commit a52bd21

Please sign in to comment.