diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 82ddebc9..bbb33fb8 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -11,6 +11,7 @@ Source code is also available at: - (Unreleased) - Add support for partition by to copy into + - Fix BOOLEAN type not found in snowdialect - v1.7.0(November 22, 2024) diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index f9e2e4c8..e6baadf7 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -39,6 +39,8 @@ _CUSTOM_Float, _CUSTOM_Time, ) +from .parser.custom_type_parser import * # noqa +from .parser.custom_type_parser import _CUSTOM_DECIMAL # noqa from .parser.custom_type_parser import ischema_names, parse_type from .sql.custom_schema.custom_table_prefix import CustomTablePrefix from .util import ( diff --git a/tests/test_imports.py b/tests/test_imports.py new file mode 100644 index 00000000..0cfe5931 --- /dev/null +++ b/tests/test_imports.py @@ -0,0 +1,64 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +import importlib +import inspect + +import pytest + + +def get_classes_from_module(module_name): + """Returns a set of class names from a given module.""" + try: + module = importlib.import_module(module_name) + members = inspect.getmembers(module) + return {name for name, obj in members if inspect.isclass(obj)} + + except ImportError: + print(f"Module '{module_name}' could not be imported.") + return set() + + +def test_types_in_snowdialect(): + classes_a = get_classes_from_module( + "snowflake.sqlalchemy.parser.custom_type_parser" + ) + classes_b = get_classes_from_module("snowflake.sqlalchemy.snowdialect") + assert classes_a.issubset(classes_b), str(classes_a - classes_b) + + +@pytest.mark.parametrize( + "type_class_name", + [ + "BIGINT", + "BINARY", + "BOOLEAN", + "CHAR", + "DATE", + "DATETIME", + "DECIMAL", + "FLOAT", + "INTEGER", + "REAL", + "SMALLINT", + "TIME", + "TIMESTAMP", + "VARCHAR", + "NullType", + "_CUSTOM_DECIMAL", + "ARRAY", + "DOUBLE", + "GEOGRAPHY", + "GEOMETRY", + "MAP", + "OBJECT", + "TIMESTAMP_LTZ", + "TIMESTAMP_NTZ", + "TIMESTAMP_TZ", + "VARIANT", + ], +) +def test_snowflake_data_types_instance(type_class_name): + classes_b = get_classes_from_module("snowflake.sqlalchemy.snowdialect") + assert type_class_name in classes_b, type_class_name