Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve index reflection #556

Merged
merged 6 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@ Source code is also available at:

# Release Notes

- (Unreleased)
- Fix quoting of `_` as column name
- Fix index columns not being reflected
- Fix index reflection cache not working

- v1.7.1(December 02, 2024)
- Add support for partition by to copy into <location>
- Fix BOOLEAN type not found in snowdialect

- v1.7.0(November 21, 2024)

- Fixed quoting of `_` as column name
- Add support for dynamic tables and required options
- Add support for hybrid tables
- Fixed SAWarning when registering functions with existing name in default namespace
Expand Down
16 changes: 16 additions & 0 deletions src/snowflake/sqlalchemy/parser/custom_type_parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
from typing import List

import sqlalchemy.types as sqltypes
from sqlalchemy.sql.type_api import TypeEngine
Expand Down Expand Up @@ -107,6 +108,21 @@ def extract_parameters(text: str) -> list:
return output_parameters


def parse_index_columns(columns: str) -> List[str]:
    """
    Parses a string with a list of columns for an index.

    :param columns: A string with a comma-separated list of columns for an index,
        optionally wrapped in square brackets.

    :return: A list of column names as strings, with surrounding whitespace removed.
        An empty list is returned when no columns are listed (``""`` or ``"[]"``).

    :example:
        For input `"[A, B, C]"`, the output is `['A', 'B', 'C']`.
        For input `"[]"`, the output is `[]`.
    """
    if not columns:
        return []
    # Drop outer whitespace first so the brackets are really at the string ends.
    inner = columns.strip().strip("[]")
    names = (column.strip() for column in inner.split(","))
    # "".split(",") yields [""]; filter blanks so an empty list stays empty.
    return [name for name in names if name]


def parse_type(type_text: str) -> TypeEngine:
"""
Parses a type definition string and returns the corresponding SQLAlchemy type.
Expand Down
184 changes: 84 additions & 100 deletions src/snowflake/sqlalchemy/snowdialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import re
from collections import defaultdict
from functools import reduce
from typing import Any
from typing import Any, Collection, Optional
from urllib.parse import unquote_plus

import sqlalchemy.types as sqltypes
Expand Down Expand Up @@ -41,7 +41,7 @@
)
from .parser.custom_type_parser import * # noqa
from .parser.custom_type_parser import _CUSTOM_DECIMAL # noqa
from .parser.custom_type_parser import ischema_names, parse_type
from .parser.custom_type_parser import ischema_names, parse_index_columns, parse_type
from .sql.custom_schema.custom_table_prefix import CustomTablePrefix
from .util import (
_update_connection_application_name,
Expand Down Expand Up @@ -674,27 +674,43 @@ def get_columns(self, connection, table_name, schema=None, **kw):
raise sa_exc.NoSuchTableError()
return schema_columns[normalized_table_name]

def get_prefixes_from_data(self, name_to_index_map, row, **kw):
    """
    Collects the names of the table prefixes flagged on a result row.

    For every member of ``CustomTablePrefix`` the column ``is_<member name, lowercased>``
    is looked up; the member's name is reported when that column is present in
    ``name_to_index_map`` and the row holds ``"Y"`` for it.

    :param name_to_index_map: Mapping of column name to its position in ``row``.
    :param row: A single result row, indexable by position.
    :param kw: Accepted but not used here.

    :return: A list with the names of the prefixes found, in enum order.
    """
    flag_columns = (
        (prefix.name, f"is_{prefix.name.lower()}") for prefix in CustomTablePrefix
    )
    return [
        prefix_name
        for prefix_name, column in flag_columns
        if column in name_to_index_map and row[name_to_index_map[column]] == "Y"
    ]

@reflection.cache
def _get_schema_tables_info(self, connection, schema=None, **kw):
    """
    Retrieves information about all tables in the specified schema.

    Runs ``SHOW TABLES IN SCHEMA`` against ``schema`` (falling back to
    ``self.default_schema_name``) and records, per table, the prefixes
    reported by ``get_prefixes_from_data``.

    :param connection: Connection used to execute the statement.
    :param schema: Schema to inspect; the default schema is used when falsy.
    :param kw: Extra keyword arguments (e.g. ``info_cache``); not read here.

    :return: A dict mapping the normalized table name to ``{"prefixes": [...]}``.
    """
    schema = schema or self.default_schema_name
    statement = text(
        f"SHOW /* sqlalchemy:get_schema_tables_info */ TABLES IN SCHEMA {self._denormalize_quote_join(schema)}"
    )
    result = connection.execute(statement)

    column_positions = self._map_name_to_idx(result)
    # Key is evaluated before the value, so normalize_name still runs before
    # get_prefixes_from_data for each row.
    return {
        self.normalize_name(str(row[column_positions["name"]])): {
            "prefixes": self.get_prefixes_from_data(column_positions, row)
        }
        for row in result.cursor.fetchall()
    }

def get_table_names(self, connection, schema=None, **kw):
    """
    Gets all table names.

    :param connection: Connection forwarded to ``_get_schema_tables_info``.
    :param schema: Schema to inspect; forwarded unchanged.
    :param kw: Only ``info_cache`` is read and forwarded.

    :return: A list with the table names, i.e. the keys of the mapping
        returned by ``_get_schema_tables_info``.
    """
    tables_info = self._get_schema_tables_info(
        connection, schema, info_cache=kw.get("info_cache", None)
    )
    # Materialize the keys: callers of the reflection API expect a real list
    # (indexable, sortable), not a live dict view tied to the cached mapping.
    return list(tables_info)

@reflection.cache
Expand Down Expand Up @@ -748,17 +764,12 @@ def get_view_definition(self, connection, view_name, schema=None, **kw):

def get_temp_table_names(self, connection, schema=None, **kw):
schema = schema or self.default_schema_name
if schema:
cursor = connection.execute(
text(
f"SHOW /* sqlalchemy:get_temp_table_names */ TABLES \
IN {self._denormalize_quote_join(schema)}"
)
)
else:
cursor = connection.execute(
text("SHOW /* sqlalchemy:get_temp_table_names */ TABLES")
cursor = connection.execute(
text(
f"SHOW /* sqlalchemy:get_temp_table_names */ TABLES \
IN SCHEMA {self._denormalize_quote_join(schema)}"
)
)

ret = []
n2i = self.__class__._map_name_to_idx(cursor)
Expand Down Expand Up @@ -839,62 +850,79 @@ def get_table_comment(self, connection, table_name, schema=None, **kw):
)
}

def get_multi_indexes(
def get_table_names_with_prefix(
    self,
    connection,
    *,
    schema,
    prefix,
    **kw,
):
    """
    Gets the names of the tables in a schema that carry a given prefix.

    :param connection: Connection forwarded to ``_get_schema_tables_info``.
    :param schema: Schema to inspect; forwarded unchanged.
    :param prefix: Prefix name to look for in each table's ``"prefixes"`` list.
    :param kw: Forwarded as-is to ``_get_schema_tables_info``.

    :return: A list of table names whose ``"prefixes"`` entry contains ``prefix``.
    """
    schema_tables = self._get_schema_tables_info(connection, schema, **kw)
    return [
        name for name, details in schema_tables.items() if prefix in details["prefixes"]
    ]

def get_multi_indexes(
    self,
    connection,
    *,
    schema: Optional[str] = None,
    filter_names: Optional[Collection[str]] = None,
    **kw,
):
    """
    Gets the indexes definition

    Only tables carrying the ``HYBRID`` prefix are considered. When the schema
    has none, no ``SHOW INDEXES`` statement is issued and an empty list is
    returned.

    :param connection: Connection used to execute the statements.
    :param schema: Schema to inspect; ``self.default_schema_name`` when falsy.
    :param filter_names: Table names to restrict the result to; ``None`` means
        no restriction.
    :param kw: Only ``info_cache`` is read and forwarded to the table lookup.

    :return: A list of ``((schema, table_name), [index, ...])`` pairs, where each
        index is a dict with ``name``, ``unique``, ``column_names``,
        ``include_columns`` and ``dialect_options``.
    """
    schema = schema or self.default_schema_name
    hybrid_table_names = self.get_table_names_with_prefix(
        connection,
        schema=schema,
        prefix=CustomTablePrefix.HYBRID.name,
        info_cache=kw.get("info_cache", None),
    )
    if not hybrid_table_names:
        return []

    result = connection.execute(
        text(
            f"SHOW /* sqlalchemy:get_multi_indexes */ INDEXES IN SCHEMA {self._denormalize_quote_join(schema)}"
        )
    )

    n2i = self._map_name_to_idx(result)
    hybrid_tables = set(hybrid_table_names)
    indexes = {}

    for row in result.cursor.fetchall():
        table_name = self.normalize_name(str(row[n2i["table"]]))
        if (
            # Skip the SYS_INDEX_<table>_PRIMARY entry (presumably the implicit
            # index backing the primary key; it is not reported as an index).
            row[n2i["name"]] == f'SYS_INDEX_{row[n2i["table"]]}_PRIMARY'
            # filter_names=None means "all tables"; `in None` would raise.
            or (filter_names is not None and table_name not in filter_names)
            or table_name not in hybrid_tables
        ):
            continue
        index = {
            "name": row[n2i["name"]],
            "unique": row[n2i["is_unique"]] == "Y",
            # Blank entries (e.g. from an empty "[]" column list) are dropped.
            "column_names": [
                self.normalize_name(column)
                for column in parse_index_columns(row[n2i["columns"]])
                if column
            ],
            "include_columns": [
                self.normalize_name(column)
                for column in parse_index_columns(row[n2i["included_columns"]])
                if column
            ],
            "dialect_options": {},
        }

        # list.append returns None, so its result must never be assigned back;
        # setdefault creates the per-table list on first use.
        indexes.setdefault((schema, table_name), []).append(index)

    return list(indexes.items())

Expand All @@ -906,50 +934,6 @@ def _value_or_default(self, data, table, schema):
else:
return []

def get_prefixes_from_data(self, n2i, row, **kw):
prefixes_found = []
for valid_prefix in CustomTablePrefix:
key = f"is_{valid_prefix.name.lower()}"
if key in n2i and row[n2i[key]] == "Y":
prefixes_found.append(valid_prefix.name)
return prefixes_found

@reflection.cache
def get_multi_prefixes(
self, connection, schema, table_name=None, filter_prefix=None, **kw
):
"""
Gets all table prefixes
"""
schema = schema or self.default_schema_name
filter = f"LIKE '{table_name}'" if table_name else ""
if schema:
result = connection.execute(
text(
f"SHOW /* sqlalchemy:get_multi_prefixes */ {filter} TABLES IN SCHEMA {schema}"
)
)
else:
result = connection.execute(
text(
f"SHOW /* sqlalchemy:get_multi_prefixes */ {filter} TABLES LIKE '{table_name}'"
)
)

n2i = self.__class__._map_name_to_idx(result)
tables_prefixes = {}
for row in result.cursor.fetchall():
table = self.normalize_name(str(row[n2i["name"]]))
table_prefixes = self.get_prefixes_from_data(n2i, row)
if filter_prefix and filter_prefix not in table_prefixes:
continue
if (schema, table) in tables_prefixes:
tables_prefixes[(schema, table)].append(table_prefixes)
else:
tables_prefixes[(schema, table)] = table_prefixes

return tables_prefixes

@reflection.cache
def get_indexes(self, connection, tablename, schema, **kw):
"""
Expand Down
27 changes: 27 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#
from __future__ import annotations

import logging.handlers
import os
import sys
import time
Expand Down Expand Up @@ -194,6 +195,32 @@ def engine_testaccount(request):
yield engine


@pytest.fixture()
def assert_text_in_buf():
    """
    Fixture yielding a checker for messages logged by ``sqlalchemy.engine``.

    A ``BufferingHandler`` is attached to the ``sqlalchemy.engine`` logger for
    the duration of the test and detached afterwards. The yielded callable
    ``go(expected, occurrences=1)`` asserts that something was logged and that
    exactly ``occurrences`` buffered messages equal ``expected``, then flushes
    the handler so the next check starts from an empty buffer.
    """
    handler = logging.handlers.BufferingHandler(100)
    engine_loggers = [logging.getLogger("sqlalchemy.engine")]
    for engine_logger in engine_loggers:
        engine_logger.addHandler(handler)

    def go(expected, occurrences=1):
        assert handler.buffer
        messages = [record.getMessage() for record in handler.buffer]

        found = messages.count(expected)
        assert occurrences == found, (
            f"Expected {occurrences} of {expected}, got {found} "
            f"occurrences in {messages}."
        )
        handler.flush()

    yield go
    for engine_logger in engine_loggers:
        engine_logger.removeHandler(handler)


@pytest.fixture()
def engine_testaccount_with_numpy(request):
url = url_factory(numpy=True)
Expand Down
Loading
Loading