From 4ce7f22e9471461fc641201d2ca254d83a328f45 Mon Sep 17 00:00:00 2001 From: Jorge Vasquez Rojas Date: Fri, 10 Jan 2025 12:16:52 -0600 Subject: [PATCH] Fix drop support for SA array --- src/snowflake/sqlalchemy/base.py | 3 +++ src/snowflake/sqlalchemy/custom_types.py | 2 +- .../__snapshots__/test_structured_datatypes.ambr | 3 +++ tests/test_structured_datatypes.py | 15 +++++++++++++++ 4 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index 0226d37d..818791cd 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -1117,6 +1117,9 @@ def visit_MAP(self, type_, **kw): ) def visit_ARRAY(self, type_, **kw): + return "ARRAY" + + def visit_SNOWFLAKE_ARRAY(self, type_, **kw): if type_.is_semi_structured: return "ARRAY" not_null = f" {NOT_NULL}" if type_.not_null else "" diff --git a/src/snowflake/sqlalchemy/custom_types.py b/src/snowflake/sqlalchemy/custom_types.py index 11cd2eb8..c742b740 100644 --- a/src/snowflake/sqlalchemy/custom_types.py +++ b/src/snowflake/sqlalchemy/custom_types.py @@ -83,7 +83,7 @@ def __repr__(self): class ARRAY(StructuredType): - __visit_name__ = "ARRAY" + __visit_name__ = "SNOWFLAKE_ARRAY" def __init__( self, diff --git a/tests/__snapshots__/test_structured_datatypes.ambr b/tests/__snapshots__/test_structured_datatypes.ambr index 3dcedf7c..453d26e4 100644 --- a/tests/__snapshots__/test_structured_datatypes.ambr +++ b/tests/__snapshots__/test_structured_datatypes.ambr @@ -5,6 +5,9 @@ # name: test_compile_table_with_double_map 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname MAP(DECIMAL, MAP(DECIMAL, VARCHAR)), \tPRIMARY KEY ("Id"))' # --- +# name: test_compile_table_with_sqlalchemy_array + 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname ARRAY, \tPRIMARY KEY ("Id"))' +# --- # name: test_compile_table_with_structured_data_type[structured_type0] 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tPRIMARY KEY ("Id"))' # --- diff --git a/tests/test_structured_datatypes.py b/tests/test_structured_datatypes.py index ce030bd2..fb73673b 100644 --- a/tests/test_structured_datatypes.py +++ b/tests/test_structured_datatypes.py @@ -2,6 +2,7 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # import pytest +import sqlalchemy as sa from sqlalchemy import ( Column, Integer, @@ -47,6 +48,20 @@ def test_compile_table_with_structured_data_type( assert sql_compiler(create_table) == snapshot +def test_compile_table_with_sqlalchemy_array(sql_compiler, snapshot): + metadata = MetaData() + user_table = Table( + "clustered_user", + metadata, + Column("Id", Integer, primary_key=True), + Column("name", sa.ARRAY(sa.String)), + ) + + create_table = CreateTable(user_table) + + assert sql_compiler(create_table) == snapshot + + @pytest.mark.requires_external_volume def test_insert_map(engine_testaccount, external_volume, base_location, snapshot): metadata = MetaData()