From fdccfcfc0bcb34e2401ad0748514acdc4b00bc58 Mon Sep 17 00:00:00 2001 From: Jorge Vasquez Rojas Date: Fri, 29 Nov 2024 17:15:27 -0600 Subject: [PATCH] Addd support for autocommit --- src/snowflake/sqlalchemy/snowdialect.py | 40 +++++- tests/conftest.py | 27 ++++ tests/test_transactions.py | 157 ++++++++++++++++++++++++ 3 files changed, 223 insertions(+), 1 deletion(-) create mode 100644 tests/test_transactions.py diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index e6baadf7..e26e9812 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -5,8 +5,9 @@ import operator import re from collections import defaultdict +from enum import Enum from functools import reduce -from typing import Any +from typing import Any, Optional from urllib.parse import unquote_plus import sqlalchemy.types as sqltypes @@ -14,6 +15,7 @@ from sqlalchemy import exc as sa_exc from sqlalchemy import util as sa_util from sqlalchemy.engine import URL, default, reflection +from sqlalchemy.engine.interfaces import IsolationLevel from sqlalchemy.schema import Table from sqlalchemy.sql import text from sqlalchemy.sql.elements import quoted_name @@ -59,6 +61,11 @@ _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME = True +class SnowflakeIsolationLevel(Enum): + READ_COMMITTED = "READ COMMITTED" + AUTOCOMMIT = "AUTOCOMMIT" + + class SnowflakeDialect(default.DefaultDialect): name = DIALECT_NAME driver = "snowflake" @@ -139,6 +146,16 @@ class SnowflakeDialect(default.DefaultDialect): supports_identity_columns = True + def __init__( + self, + isolation_level: Optional[ + IsolationLevel + ] = SnowflakeIsolationLevel.READ_COMMITTED.value, + **kwargs: Any, + ): + super().__init__(isolation_level=isolation_level, **kwargs) + self._cache_column_metadata = False + @classmethod def dbapi(cls): return cls.import_dbapi() @@ -216,6 +233,27 @@ def has_table(self, connection, table_name, schema=None, **kw): """ return self._has_object(connection, "TABLE", table_name, schema) + def get_isolation_level_values(self, dbapi_connection): + return [ + SnowflakeIsolationLevel.READ_COMMITTED.value, + SnowflakeIsolationLevel.AUTOCOMMIT.value, + ] + + def do_rollback(self, dbapi_connection): + dbapi_connection.rollback() + + def do_commit(self, dbapi_connection): + dbapi_connection.commit() + + def get_default_isolation_level(self, dbapi_conn): + return SnowflakeIsolationLevel.READ_COMMITTED.value + + def set_isolation_level(self, dbapi_connection, level): + if level == SnowflakeIsolationLevel.AUTOCOMMIT.value: + dbapi_connection.autocommit(True) + else: + dbapi_connection.autocommit(False) + @reflection.cache def has_sequence(self, connection, sequence_name, schema=None, **kw): """ diff --git a/tests/conftest.py b/tests/conftest.py index a91521b9..f2045121 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ # from __future__ import annotations +import logging.handlers import os import sys import time @@ -194,6 +195,32 @@ def engine_testaccount(request): yield engine +@pytest.fixture() +def assert_text_in_buf(): + buf = logging.handlers.BufferingHandler(100) + for log in [ + logging.getLogger("sqlalchemy.engine"), + ]: + log.addHandler(buf) + + def go(expected, occurrences=1): + assert buf.buffer + buflines = [rec.getMessage() for rec in buf.buffer] + + ocurrences_found = buflines.count(expected) + assert occurrences == ocurrences_found, ( + f"Expected {occurrences} of {expected}, got {ocurrences_found} " + f"occurrences in {buflines}." + ) + buf.flush() + + yield go + for log in [ + logging.getLogger("sqlalchemy.engine"), + ]: + log.removeHandler(buf) + + @pytest.fixture() def engine_testaccount_with_numpy(request): url = url_factory(numpy=True) diff --git a/tests/test_transactions.py b/tests/test_transactions.py new file mode 100644 index 00000000..c163c2b7 --- /dev/null +++ b/tests/test_transactions.py @@ -0,0 +1,157 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from sqlalchemy import Column, Integer, MetaData, String, select, text + +from snowflake.sqlalchemy import SnowflakeTable + +CURRENT_TRANSACTION = text("SELECT CURRENT_TRANSACTION()") + + +def test_connect_read_commited(engine_testaccount, assert_text_in_buf): + metadata = MetaData() + table_name = "test_connect_read_commited" + + test_table_1 = SnowflakeTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + cluster_by=["id", text("id > 5")], + ) + + metadata.create_all(engine_testaccount) + try: + with engine_testaccount.connect().execution_options( + isolation_level="READ COMMITTED" + ) as connection: + result = connection.execute(CURRENT_TRANSACTION).fetchall() + assert result[0] == (None,), result + ins = test_table_1.insert().values(id=1, name="test") + connection.execute(ins) + result = connection.execute(CURRENT_TRANSACTION).fetchall() + assert result[0] != ( + None, + ), "AUTOCOMMIT DISABLED, transaction should be started" + + with engine_testaccount.connect() as conn: + s = select(test_table_1) + results = conn.execute(s).fetchall() + assert len(results) == 0, results # No insert commited + assert_text_in_buf("ROLLBACK", occurrences=1) + finally: + metadata.drop_all(engine_testaccount) + + +def test_begin_read_commited(engine_testaccount, assert_text_in_buf): + metadata = MetaData() + table_name = "test_begin_read_commited" + + test_table_1 = SnowflakeTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + cluster_by=["id", text("id > 5")], + ) + + metadata.create_all(engine_testaccount) + try: + with engine_testaccount.connect().execution_options( + isolation_level="READ COMMITTED" + ) as connection, connection.begin(): + result = connection.execute(CURRENT_TRANSACTION).fetchall() + assert result[0] == (None,), result + ins = test_table_1.insert().values(id=1, name="test") + connection.execute(ins) + result = connection.execute(CURRENT_TRANSACTION).fetchall() + assert result[0] != ( + None, + ), "AUTOCOMMIT DISABLED, transaction should be started" + + with engine_testaccount.connect() as conn: + s = select(test_table_1) + results = conn.execute(s).fetchall() + assert len(results) == 1, results # Insert commited + assert_text_in_buf("COMMIT", occurrences=2) + finally: + metadata.drop_all(engine_testaccount) + + +def test_connect_autocommit(engine_testaccount, assert_text_in_buf): + metadata = MetaData() + table_name = "test_connect_autocommit" + + test_table_1 = SnowflakeTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + cluster_by=["id", text("id > 5")], + ) + + metadata.create_all(engine_testaccount) + try: + with engine_testaccount.connect().execution_options( + isolation_level="AUTOCOMMIT" + ) as connection: + result = connection.execute(CURRENT_TRANSACTION).fetchall() + assert result[0] == (None,), result + ins = test_table_1.insert().values(id=1, name="test") + connection.execute(ins) + result = connection.execute(CURRENT_TRANSACTION).fetchall() + assert result[0] == ( + None, + ), "Autocommit enabled, transaction should not be started" + + with engine_testaccount.connect() as conn: + s = select(test_table_1) + results = conn.execute(s).fetchall() + assert len(results) == 1, results + assert_text_in_buf( + "ROLLBACK using DBAPI connection.rollback(), DBAPI should ignore due to autocommit mode", + occurrences=1, + ) + + finally: + metadata.drop_all(engine_testaccount) + + +def test_begin_autocommit(engine_testaccount, assert_text_in_buf): + metadata = MetaData() + table_name = "test_begin_autocommit" + + test_table_1 = SnowflakeTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + cluster_by=["id", text("id > 5")], + ) + + metadata.create_all(engine_testaccount) + try: + with engine_testaccount.connect().execution_options( + isolation_level="AUTOCOMMIT" + ) as connection, connection.begin(): + result = connection.execute(CURRENT_TRANSACTION).fetchall() + assert result[0] == (None,), result + ins = test_table_1.insert().values(id=1, name="test") + connection.execute(ins) + result = connection.execute(CURRENT_TRANSACTION).fetchall() + assert result[0] == ( + None, + ), "Autocommit enabled, transaction should not be started" + + with engine_testaccount.connect() as conn: + s = select(test_table_1) + results = conn.execute(s).fetchall() + assert len(results) == 1, results + assert_text_in_buf( + "COMMIT using DBAPI connection.commit(), DBAPI should ignore due to autocommit mode", + occurrences=1, + ) + + finally: + metadata.drop_all(engine_testaccount)