diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index 71749830..6e339af2 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -719,6 +719,7 @@ def connect_to_snowflake( database: str, role: Union[str, None] = None, warehouse: Union[str, None] = None, + **kwargs ): try: snowflake = __import__("snowflake.connector") @@ -765,7 +766,8 @@ def connect_to_snowflake( password=password, account=account, database=database, - client_session_keep_alive=True + client_session_keep_alive=True, + **kwargs ) def run_sql_snowflake(sql: str) -> pd.DataFrame: @@ -791,13 +793,13 @@ def run_sql_snowflake(sql: str) -> pd.DataFrame: self.run_sql = run_sql_snowflake self.run_sql_is_set = True - def connect_to_sqlite(self, url: str): + def connect_to_sqlite(self, url: str, check_same_thread: bool = False, **kwargs): """ Connect to a SQLite database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql] Args: url (str): The URL of the database to connect to. - + check_same_thread (str): Allow the connection may be accessed in multiple threads. Returns: None """ @@ -816,7 +818,11 @@ def connect_to_sqlite(self, url: str): url = path # Connect to the database - conn = sqlite3.connect(url, check_same_thread=False) + conn = sqlite3.connect( + url, + check_same_thread=check_same_thread, + **kwargs + ) def run_sql_sqlite(sql: str): return pd.read_sql_query(sql, conn) @@ -832,6 +838,7 @@ def connect_to_postgres( user: str = None, password: str = None, port: int = None, + **kwargs ): """ Connect to postgres using the psycopg2 connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql] @@ -901,6 +908,7 @@ def connect_to_postgres( user=user, password=password, port=port, + **kwargs ) except psycopg2.Error as e: raise ValidationError(e) @@ -932,12 +940,13 @@ def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]: def connect_to_mysql( - self, - host: str = None, - dbname: str = None, - user: str = None, - password: str = None, - port: int = None, + self, + host: str = None, + dbname: str = None, + user: str = None, + password: str = None, + port: int = None, + **kwargs ): try: @@ -981,12 +990,15 @@ def connect_to_mysql( conn = None try: - conn = pymysql.connect(host=host, - user=user, - password=password, - database=dbname, - port=port, - cursorclass=pymysql.cursors.DictCursor) + conn = pymysql.connect( + host=host, + user=user, + password=password, + database=dbname, + port=port, + cursorclass=pymysql.cursors.DictCursor, + **kwargs + ) except pymysql.Error as e: raise ValidationError(e) @@ -1016,12 +1028,13 @@ def run_sql_mysql(sql: str) -> Union[pd.DataFrame, None]: self.run_sql = run_sql_mysql def connect_to_clickhouse( - self, - host: str = None, - dbname: str = None, - user: str = None, - password: str = None, - port: int = None, + self, + host: str = None, + dbname: str = None, + user: str = None, + password: str = None, + port: int = None, + **kwargs ): try: @@ -1071,6 +1084,7 @@ def connect_to_clickhouse( username=user, password=password, database=dbname, + **kwargs ) print(conn) except Exception as e: @@ -1093,10 +1107,11 @@ def run_sql_clickhouse(sql: str) -> Union[pd.DataFrame, None]: self.run_sql = run_sql_clickhouse def connect_to_oracle( - self, - user: str = None, - password: str = None, - dsn: str = None, + self, + user: str = None, + password: str = None, + dsn: str = None, + **kwargs ): """ @@ -1149,7 +1164,8 @@ def connect_to_oracle( user=user, password=password, dsn=dsn, - ) + **kwargs + ) except oracledb.Error as e: raise ValidationError(e) @@ -1181,7 +1197,12 @@ def run_sql_oracle(sql: str) -> Union[pd.DataFrame, None]: self.run_sql_is_set = True self.run_sql = run_sql_oracle - def connect_to_bigquery(self, cred_file_path: str = None, project_id: str = None): + def connect_to_bigquery( + self, + cred_file_path: str = None, + project_id: str = None, + **kwargs + ): """ Connect to gcs using the bigquery connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql] **Example:** @@ -1243,7 +1264,11 @@ def connect_to_bigquery(self, cred_file_path: str = None, project_id: str = None ) try: - conn = bigquery.Client(project=project_id, credentials=credentials) + conn = bigquery.Client( + project=project_id, + credentials=credentials, + **kwargs + ) except: raise ImproperlyConfigured( "Could not connect to bigquery please correct credentials" @@ -1266,7 +1291,7 @@ def run_sql_bigquery(sql: str) -> Union[pd.DataFrame, None]: self.run_sql_is_set = True self.run_sql = run_sql_bigquery - def connect_to_duckdb(self, url: str, init_sql: str = None): + def connect_to_duckdb(self, url: str, init_sql: str = None, **kwargs): """ Connect to a DuckDB database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql] @@ -1304,7 +1329,7 @@ def connect_to_duckdb(self, url: str, init_sql: str = None): f.write(response.content) # Connect to the database - conn = duckdb.connect(path) + conn = duckdb.connect(path, **kwargs) if init_sql: conn.query(init_sql) @@ -1315,7 +1340,7 @@ def run_sql_duckdb(sql: str): self.run_sql = run_sql_duckdb self.run_sql_is_set = True - def connect_to_mssql(self, odbc_conn_str: str): + def connect_to_mssql(self, odbc_conn_str: str, **kwargs): """ Connect to a Microsoft SQL Server database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql] @@ -1348,7 +1373,7 @@ def connect_to_mssql(self, odbc_conn_str: str): from sqlalchemy import create_engine - engine = create_engine(connection_url) + engine = create_engine(connection_url, **kwargs) def run_sql_mssql(sql: str): # Execute the SQL statement and return the result as a pandas DataFrame @@ -1363,16 +1388,17 @@ def run_sql_mssql(sql: str): self.run_sql = run_sql_mssql self.run_sql_is_set = True def connect_to_presto( - self, - host: str, - catalog: str = 'hive', - schema: str = 'default', - user: str = None, - password: str = None, - port: int = None, - combined_pem_path: str = None, - protocol: str = 'https', - requests_kwargs: dict = None + self, + host: str, + catalog: str = 'hive', + schema: str = 'default', + user: str = None, + password: str = None, + port: int = None, + combined_pem_path: str = None, + protocol: str = 'https', + requests_kwargs: dict = None, + **kwargs ): """ Connect to a Presto database using the specified parameters. @@ -1445,7 +1471,8 @@ def connect_to_presto( schema=schema, port=port, protocol=protocol, - requests_kwargs=requests_kwargs) + requests_kwargs=requests_kwargs, + **kwargs) except presto.Error as e: raise ValidationError(e) @@ -1478,13 +1505,14 @@ def run_sql_presto(sql: str) -> Union[pd.DataFrame, None]: self.run_sql = run_sql_presto def connect_to_hive( - self, - host: str = None, - dbname: str = 'default', - user: str = None, - password: str = None, - port: int = None, - auth: str = 'CUSTOM' + self, + host: str = None, + dbname: str = 'default', + user: str = None, + password: str = None, + port: int = None, + auth: str = 'CUSTOM', + **kwargs ): """ Connect to a Hive database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]