diff --git a/docs/notebooks/vn-ask.md b/docs/notebooks/vn-ask.md index 1249c16d..f6f4bd90 100644 --- a/docs/notebooks/vn-ask.md +++ b/docs/notebooks/vn-ask.md @@ -1,10 +1,10 @@ -![Vanna AI](https://img.vanna.ai/vanna-full.svg) +![Vanna AI](https://img.vanna.ai/vanna-ask.svg) -This notebook will help you unleash the full potential of AI-powered data analysis at your organization. We'll go through how to "bulk train" Vanna and generate SQL, tables, charts, and explanations, all with minimal code and effort. For more about Vanna, see our [intro blog post](https://medium.com/vanna-ai/intro-to-vanna-a-python-based-ai-sql-co-pilot-218c25b19c6a). +The following notebook goes through the process of asking questions from your data using Vanna AI. Here we use a demo model that is pre-trained on the [TPC-H dataset](https://docs.snowflake.com/en/user-guide/sample-data-tpch.html) that is available in Snowflake. -[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vanna-ai/vanna-py/blob/main/notebooks/vn-full.ipynb) +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vanna-ai/vanna-py/blob/main/notebooks/vn-ask.ipynb) -[![Open in GitHub](https://img.vanna.ai/github.svg)](https://github.com/vanna-ai/vanna-py/blob/main/notebooks/vn-full.ipynb) +[![Open in GitHub](https://img.vanna.ai/github.svg)](https://github.com/vanna-ai/vanna-py/blob/main/notebooks/vn-ask.ipynb) # Install Vanna First we install Vanna from [PyPI](https://pypi.org/project/vanna/) and import it. @@ -20,7 +20,6 @@ Here, we'll also install the Snowflake connector. If you're using a different da ```python import vanna as vn import snowflake.connector -import pandas as pd ``` # Login @@ -37,7 +36,7 @@ You need to choose a globally unique model name. Try using your company name or ```python -vn.set_model('my-model') # Enter your dataset name here. This is a globally unique identifier for your dataset. +vn.set_model('tpc') # Enter your model name here. This is a globally unique identifier for your model. ``` # Set Database Connection @@ -64,20 +63,105 @@ vn.ask("What are the top 10 customers by sales?") GROUP BY customer_name ORDER BY total_sales desc limit 10; -![plot1](plot1.png) - AI-generated follow-up questions: - What are the countries of the top 10 customers by sales? - How many orders did each of the top 10 customers place? - What is the average sales amount per customer in the top 10? - Can you provide a breakdown of the sales by country for the top 10 customers? - Who are the top 10 customers in terms of returned parts gross value? - What are the total sales for each customer in the top 3? - Can you provide a breakdown of the sales by region for the top customers? - How many customers are there in each country? - What is the total revenue for the top 5 countries? - Can you provide a breakdown of the sales by customer for the top 5 countries? +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
CUSTOMER_NAMETOTAL_SALES
0Customer#0001435006757566.0218
1Customer#0000952576294115.3340
2Customer#0000871156184649.5176
3Customer#0001311136080943.8305
4Customer#0001343806075141.9635
5Customer#0001038346059770.3232
6Customer#0000696826057779.0348
7Customer#0001020226039653.6335
8Customer#0000985876027021.5855
9Customer#0000646605905659.6159
+
+ + + + +![png](vn-ask_files/vn-ask_10_2.png) + + + + +AI-generated follow-up questions: + +* What is the country name for each of the top 10 customers by sales? +* How many orders does each of the top 10 customers by sales have? +* What is the total revenue for each of the top 10 customers by sales? +* What are the customer names and total sales for customers in the United States? +* Which customers in Africa have returned the most parts with a gross value? +* What are the total sales for the top 3 customers? +* What are the customer names and total sales for the top 5 customers? +* What are the total sales for customers in Europe? +* How many customers are there in each country? + @@ -149,19 +233,26 @@ vn.ask("Which 5 countries have the highest sales?") -![plot2](plot2.png) + +![png](vn-ask_files/vn-ask_11_2.png) + + + + +AI-generated follow-up questions: + +* What are the total sales for each country in descending order? +* Which country has the highest number of customers? +* What are the total sales for each customer in descending order? +* Which customers in the United States have the highest total sales? +* What are the total sales and number of orders for each customer in each country? +* What are the total sales for customers in Europe? +* What are the top 10 countries with the highest total order amount? +* Which country has the highest number of failed orders? +* Which customers have the highest total sales? +* - AI-generated follow-up questions: - What are the total sales for each country? - Which country has the highest number of customers? - What are the total sales for each customer? - What are the top 3 customers with the highest sales? - What is the total revenue for each customer and country? - What are the total sales for each customer in Europe? - What are the top 10 countries with the highest total order amount? - Which country has the highest number of failed orders? - What are the top 3 customers with the highest sales? @@ -264,20 +355,26 @@ vn.ask("Who are the top 2 biggest customers in each region?") -![plot3](plot3.png) + +![png](vn-ask_files/vn-ask_12_2.png) + + + + +AI-generated follow-up questions: + +* - What are the total sales for each customer in the Asia region? +* - How many orders does each customer in the Americas region have? +* - Who are the top 5 customers with the highest total sales? +* - What is the total revenue for each customer in the Europe region? +* - Can you provide a breakdown of the number of customers in each country? +* - Which customers in the United States have the highest total sales? +* - What are the total sales for each customer in the Asia region? +* - What are the top 10 customers with the highest returned parts gross value in Africa? +* - What are the top 3 customers with the highest total sales overall? +* - Can you provide a list of the first 10 customers in the database? - AI-generated follow-up questions: - - What are the total sales for each customer in Europe? - - What are the total sales for each customer in the United States? - - How many customers are there in each country? - - What is the total revenue for each customer in each country? - - Which customers have the highest total sales? - - Which customers have the highest number of orders? - - Which customers have the highest returned parts gross value in Africa? - - What are the total sales for the top 3 customers? - - What are the total sales for the top 10 customers? - - What is the total sales for each customer? # Run as a Web App diff --git a/docs/notebooks/vn-ask_files/vn-ask_10_2.png b/docs/notebooks/vn-ask_files/vn-ask_10_2.png new file mode 100644 index 00000000..0a4fb601 Binary files /dev/null and b/docs/notebooks/vn-ask_files/vn-ask_10_2.png differ diff --git a/docs/notebooks/vn-ask_files/vn-ask_11_2.png b/docs/notebooks/vn-ask_files/vn-ask_11_2.png new file mode 100644 index 00000000..1a9cd7fd Binary files /dev/null and b/docs/notebooks/vn-ask_files/vn-ask_11_2.png differ diff --git a/docs/notebooks/vn-ask_files/vn-ask_12_2.png b/docs/notebooks/vn-ask_files/vn-ask_12_2.png new file mode 100644 index 00000000..0210f1f9 Binary files /dev/null and b/docs/notebooks/vn-ask_files/vn-ask_12_2.png differ diff --git a/docs/notebooks/vn-train.md b/docs/notebooks/vn-train.md index 83d1d778..2cbb44f7 100644 --- a/docs/notebooks/vn-train.md +++ b/docs/notebooks/vn-train.md @@ -1,3 +1,269 @@ +![Vanna AI](https://img.vanna.ai/vanna-train.svg) + +The following notebook goes through the process of training Vanna. + +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vanna-ai/vanna-py/blob/main/notebooks/vn-train.ipynb) + +[![Open in GitHub](https://img.vanna.ai/github.svg)](https://github.com/vanna-ai/vanna-py/blob/main/notebooks/vn-ask.ipynb) + +# Install Vanna +First we install Vanna from [PyPI](https://pypi.org/project/vanna/) and import it. +Here, we'll also install the Snowflake connector. If you're using a different database, you'll need to install the appropriate connector. + + +```python +%pip install vanna +%pip install snowflake-connector-python +``` + + ```python +import vanna as vn +import snowflake.connector +``` + +# Login +Creating a login and getting an API key is as easy as entering your email (after you run this cell) and entering the code we send to you. Check your Spam folder if you don't see the code. + + +```python +api_key = vn.get_api_key('my-email@example.com') +vn.set_api_key(api_key) +``` + +# Set your Model +You need to choose a globally unique model name. Try using your company name or another unique string. All data from models are isolated - there's no leakage. + + +```python +vn.set_model('my-model') # Enter your model name here. This is a globally unique identifier for your model. +``` + +# Automatic Training +If you'd like to use automatic training, the Vanna package can crawl your database to fetch metadata to train your model. You can put in your Snowflake credentials here. These details are only referenced within your notebook. These database credentials are never sent to Vanna's severs. + + +```python +vn.connect_to_snowflake(account='my-account', username='my-username', password='my-password', database='my-database') +``` + + +```python +training_plan = vn.get_training_plan_experimental(filter_databases=['SNOWFLAKE_SAMPLE_DATA'], filter_schemas=['TPCH_SF1']) +training_plan +``` + + Trying query history + Trying INFORMATION_SCHEMA.COLUMNS for SNOWFLAKE_SAMPLE_DATA + + + + + + Train on SQL: What are the top 10 customers ranked by total sales? + Train on SQL: What are the top 10 customers in terms of total sales? + Train on SQL: What are the top two customers with the highest total sales for each region? + Train on SQL: What are the top 5 customers with the highest total sales? + Train on SQL: What is the total quantity of each product sold in each region, ordered by region name and total quantity in descending order? + Train on SQL: What is the number of orders for each week, starting from the most recent week? + Train on SQL: What countries are in the region 'EUROPE'? + Train on Information Schema: SNOWFLAKE_SAMPLE_DATA.TPCH_SF1 SUPPLIER + Train on Information Schema: SNOWFLAKE_SAMPLE_DATA.TPCH_SF1 LINEITEM + Train on Information Schema: SNOWFLAKE_SAMPLE_DATA.TPCH_SF1 CUSTOMER + Train on Information Schema: SNOWFLAKE_SAMPLE_DATA.TPCH_SF1 PARTSUPP + Train on Information Schema: SNOWFLAKE_SAMPLE_DATA.TPCH_SF1 PART + Train on Information Schema: SNOWFLAKE_SAMPLE_DATA.TPCH_SF1 ORDERS + Train on Information Schema: SNOWFLAKE_SAMPLE_DATA.TPCH_SF1 REGION + Train on Information Schema: SNOWFLAKE_SAMPLE_DATA.TPCH_SF1 NATION + + + +```python +vn.train(plan=training_plan) +``` + +# Train with DDL Statements +If you prefer to manually train, you do not need to connect to a database. You can use the train function with other parmaeters like ddl + + +```python +vn.train(ddl=""" + CREATE TABLE IF NOT EXISTS my-table ( + id INT PRIMARY KEY, + name VARCHAR(100), + age INT + ) +""") +``` + +# Train with Documentation +Sometimes you may want to add documentation about your business terminology or definitions. + + +```python +vn.train(documentation="Our business defines OTIF score as the percentage of orders that are delivered on time and in full") +``` + +# Train with SQL +You can also add SQL queries to your training data. This is useful if you have some queries already laying around. You can just copy and paste those from your editor to begin generating new SQL. + + +```python +vn.train(sql="SELECT * FROM my-table WHERE name = 'John Doe'") +``` + +# View Training Data +At any time you can see what training data is in your model + + +```python +vn.get_training_data() +``` + + + + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
idtraining_data_typequestioncontent
015-docdocumentationNoneThis is a table in the PARTSUPP table.\n\nThe ...
111-docdocumentationNoneThis is a table in the CUSTOMER table.\n\nThe ...
214-docdocumentationNoneThis is a table in the ORDERS table.\n\nThe fo...
31244-sqlsqlWhat are the names of the top 10 customers?SELECT c.c_name as customer_name\nFROM snowf...
41242-sqlsqlWhat are the top 5 customers in terms of total...SELECT c.c_name AS customer_name, SUM(l.l_quan...
517-docdocumentationNoneThis is a table in the REGION table.\n\nThe fo...
616-docdocumentationNoneThis is a table in the PART table.\n\nThe foll...
71243-sqlsqlWhat are the top 10 customers with the highest...SELECT c.c_name as customer_name,\n sum(...
81239-sqlsqlWhat are the top 100 customers based on their ...SELECT c.c_name as customer_name,\n sum(...
913-docdocumentationNoneThis is a table in the SUPPLIER table.\n\nThe ...
101241-sqlsqlWhat are the top 10 customers in terms of tota...SELECT c.c_name as customer_name,\n sum(...
1112-docdocumentationNoneThis is a table in the LINEITEM table.\n\nThe ...
1218-docdocumentationNoneThis is a table in the NATION table.\n\nThe fo...
131248-sqlsqlHow many customers are in each country?SELECT n.n_name as country,\n count(*) a...
141240-sqlsqlWhat is the number of orders placed each week?SELECT date_trunc('week', o_orderdate) as week...
+
+ + + +# Removing Training Data +If you added some training data by mistake, you can remove it. Model performance is directly linked to the quality of the training data. + + +```python +vn.remove_training_data(id='my-training-data-id') ``` diff --git a/mkdocs.yml b/mkdocs.yml index 03600339..aa111681 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -3,12 +3,12 @@ nav: - Intro: - What is Vanna.AI?: index.md - Intro to Vanna: intro-to-vanna.md - - How Vanna Works: - - Onboarding: onboarding.md - - Adding Vanna to your Workflow: workflow.md - Use in Notebooks: - Asking Questions: notebooks/vn-ask.md - Training Vanna: notebooks/vn-train.md + - How Vanna Works: + - Onboarding: onboarding.md + - Adding Vanna to your Workflow: workflow.md - Other Ways to Use Vanna: - Use with Streamlit: streamlit.md - Databases: diff --git a/notebooks/vn-train.ipynb b/notebooks/vn-train.ipynb index a6100842..6ecb6de1 100644 --- a/notebooks/vn-train.ipynb +++ b/notebooks/vn-train.ipynb @@ -1,19 +1,461 @@ { "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![Vanna AI](https://img.vanna.ai/vanna-train.svg)\n", + "\n", + "The following notebook goes through the process of training Vanna. \n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vanna-ai/vanna-py/blob/main/notebooks/vn-train.ipynb)\n", + "\n", + "[![Open in GitHub](https://img.vanna.ai/github.svg)](https://github.com/vanna-ai/vanna-py/blob/main/notebooks/vn-ask.ipynb)\n", + "\n", + "# Install Vanna\n", + "First we install Vanna from [PyPI](https://pypi.org/project/vanna/) and import it.\n", + "Here, we'll also install the Snowflake connector. If you're using a different database, you'll need to install the appropriate connector." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install vanna\n", + "%pip install snowflake-connector-python" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import vanna as vn\n", + "import snowflake.connector" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Login\n", + "Creating a login and getting an API key is as easy as entering your email (after you run this cell) and entering the code we send to you. Check your Spam folder if you don't see the code." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "api_key = vn.get_api_key('my-email@example.com')\n", + "vn.set_api_key(api_key)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Set your Model\n", + "You need to choose a globally unique model name. Try using your company name or another unique string. All data from models are isolated - there's no leakage." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "vn.set_model('my-model') # Enter your model name here. This is a globally unique identifier for your model." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Automatic Training\n", + "If you'd like to use automatic training, the Vanna package can crawl your database to fetch metadata to train your model. You can put in your Snowflake credentials here. These details are only referenced within your notebook. These database credentials are never sent to Vanna's severs." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "vn.connect_to_snowflake(account='my-account', username='my-username', password='my-password', database='my-database')" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Trying query history\n", + "Trying INFORMATION_SCHEMA.COLUMNS for SNOWFLAKE_SAMPLE_DATA\n" + ] + }, + { + "data": { + "text/plain": [ + "Train on SQL: What are the top 10 customers ranked by total sales?\n", + "Train on SQL: What are the top 10 customers in terms of total sales?\n", + "Train on SQL: What are the top two customers with the highest total sales for each region?\n", + "Train on SQL: What are the top 5 customers with the highest total sales?\n", + "Train on SQL: What is the total quantity of each product sold in each region, ordered by region name and total quantity in descending order?\n", + "Train on SQL: What is the number of orders for each week, starting from the most recent week?\n", + "Train on SQL: What countries are in the region 'EUROPE'?\n", + "Train on Information Schema: SNOWFLAKE_SAMPLE_DATA.TPCH_SF1 SUPPLIER\n", + "Train on Information Schema: SNOWFLAKE_SAMPLE_DATA.TPCH_SF1 LINEITEM\n", + "Train on Information Schema: SNOWFLAKE_SAMPLE_DATA.TPCH_SF1 CUSTOMER\n", + "Train on Information Schema: SNOWFLAKE_SAMPLE_DATA.TPCH_SF1 PARTSUPP\n", + "Train on Information Schema: SNOWFLAKE_SAMPLE_DATA.TPCH_SF1 PART\n", + "Train on Information Schema: SNOWFLAKE_SAMPLE_DATA.TPCH_SF1 ORDERS\n", + "Train on Information Schema: SNOWFLAKE_SAMPLE_DATA.TPCH_SF1 REGION\n", + "Train on Information Schema: SNOWFLAKE_SAMPLE_DATA.TPCH_SF1 NATION" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "training_plan = vn.get_training_plan_experimental(filter_databases=['SNOWFLAKE_SAMPLE_DATA'], filter_schemas=['TPCH_SF1'])\n", + "training_plan" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vn.train(plan=training_plan)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Train with DDL Statements\n", + "If you prefer to manually train, you do not need to connect to a database. You can use the train function with other parmaeters like ddl" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vn.train(ddl=\"\"\"\n", + " CREATE TABLE IF NOT EXISTS my-table (\n", + " id INT PRIMARY KEY,\n", + " name VARCHAR(100),\n", + " age INT\n", + " )\n", + "\"\"\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Train with Documentation\n", + "Sometimes you may want to add documentation about your business terminology or definitions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vn.train(documentation=\"Our business defines OTIF score as the percentage of orders that are delivered on time and in full\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Train with SQL\n", + "You can also add SQL queries to your training data. This is useful if you have some queries already laying around. You can just copy and paste those from your editor to begin generating new SQL." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vn.train(sql=\"SELECT * FROM my-table WHERE name = 'John Doe'\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# View Training Data\n", + "At any time you can see what training data is in your model" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idtraining_data_typequestioncontent
015-docdocumentationNoneThis is a table in the PARTSUPP table.\\n\\nThe ...
111-docdocumentationNoneThis is a table in the CUSTOMER table.\\n\\nThe ...
214-docdocumentationNoneThis is a table in the ORDERS table.\\n\\nThe fo...
31244-sqlsqlWhat are the names of the top 10 customers?SELECT c.c_name as customer_name\\nFROM snowf...
41242-sqlsqlWhat are the top 5 customers in terms of total...SELECT c.c_name AS customer_name, SUM(l.l_quan...
517-docdocumentationNoneThis is a table in the REGION table.\\n\\nThe fo...
616-docdocumentationNoneThis is a table in the PART table.\\n\\nThe foll...
71243-sqlsqlWhat are the top 10 customers with the highest...SELECT c.c_name as customer_name,\\n sum(...
81239-sqlsqlWhat are the top 100 customers based on their ...SELECT c.c_name as customer_name,\\n sum(...
913-docdocumentationNoneThis is a table in the SUPPLIER table.\\n\\nThe ...
101241-sqlsqlWhat are the top 10 customers in terms of tota...SELECT c.c_name as customer_name,\\n sum(...
1112-docdocumentationNoneThis is a table in the LINEITEM table.\\n\\nThe ...
1218-docdocumentationNoneThis is a table in the NATION table.\\n\\nThe fo...
131248-sqlsqlHow many customers are in each country?SELECT n.n_name as country,\\n count(*) a...
141240-sqlsqlWhat is the number of orders placed each week?SELECT date_trunc('week', o_orderdate) as week...
\n", + "
" + ], + "text/plain": [ + " id training_data_type \\\n", + "0 15-doc documentation \n", + "1 11-doc documentation \n", + "2 14-doc documentation \n", + "3 1244-sql sql \n", + "4 1242-sql sql \n", + "5 17-doc documentation \n", + "6 16-doc documentation \n", + "7 1243-sql sql \n", + "8 1239-sql sql \n", + "9 13-doc documentation \n", + "10 1241-sql sql \n", + "11 12-doc documentation \n", + "12 18-doc documentation \n", + "13 1248-sql sql \n", + "14 1240-sql sql \n", + "\n", + " question \\\n", + "0 None \n", + "1 None \n", + "2 None \n", + "3 What are the names of the top 10 customers? \n", + "4 What are the top 5 customers in terms of total... \n", + "5 None \n", + "6 None \n", + "7 What are the top 10 customers with the highest... \n", + "8 What are the top 100 customers based on their ... \n", + "9 None \n", + "10 What are the top 10 customers in terms of tota... \n", + "11 None \n", + "12 None \n", + "13 How many customers are in each country? \n", + "14 What is the number of orders placed each week? \n", + "\n", + " content \n", + "0 This is a table in the PARTSUPP table.\\n\\nThe ... \n", + "1 This is a table in the CUSTOMER table.\\n\\nThe ... \n", + "2 This is a table in the ORDERS table.\\n\\nThe fo... \n", + "3 SELECT c.c_name as customer_name\\nFROM snowf... \n", + "4 SELECT c.c_name AS customer_name, SUM(l.l_quan... \n", + "5 This is a table in the REGION table.\\n\\nThe fo... \n", + "6 This is a table in the PART table.\\n\\nThe foll... \n", + "7 SELECT c.c_name as customer_name,\\n sum(... \n", + "8 SELECT c.c_name as customer_name,\\n sum(... \n", + "9 This is a table in the SUPPLIER table.\\n\\nThe ... \n", + "10 SELECT c.c_name as customer_name,\\n sum(... \n", + "11 This is a table in the LINEITEM table.\\n\\nThe ... \n", + "12 This is a table in the NATION table.\\n\\nThe fo... \n", + "13 SELECT n.n_name as country,\\n count(*) a... \n", + "14 SELECT date_trunc('week', o_orderdate) as week... " + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "vn.get_training_data()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Removing Training Data\n", + "If you added some training data by mistake, you can remove it. Model performance is directly linked to the quality of the training data." + ] + }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "vn.remove_training_data(id='my-training-data-id')" + ] } ], "metadata": { - "language_info": { - "name": "python" + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" }, - "orig_nbformat": 4 + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" + } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/pyproject.toml b/pyproject.toml index 491fa5f6..47c8c94e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "vanna" -version = "0.0.16" +version = "0.0.17" authors = [ { name="Zain Hoda", email="zain@vanna.ai" }, ] diff --git a/src/vanna/__init__.py b/src/vanna/__init__.py index 54ff337c..3800942e 100644 --- a/src/vanna/__init__.py +++ b/src/vanna/__init__.py @@ -583,31 +583,167 @@ def remove_item(self, item: str): -def get_training_plan() -> TrainingPlan: +def __get_databases() -> List[str]: + try: + df_databases = run_sql("SELECT * FROM INFORMATION_SCHEMA.DATABASES") + except: + try: + df_databases = run_sql("SHOW DATABASES") + except: + return [] + + return df_databases['DATABASE_NAME'].unique().tolist() + +def __get_information_schema_tables(database: str) -> pd.DataFrame: + df_tables = run_sql(f'SELECT * FROM {database}.INFORMATION_SCHEMA.TABLES') + + return df_tables + + +def get_training_plan_experimental(filter_databases: Union[List[str], None] = None, filter_schemas: Union[List[str], None] = None, include_information_schema: bool = False, use_historical_queries: bool = True) -> TrainingPlan: """ + **EXPERIMENTAL** : This method is experimental and may change in future versions. + + Get a training plan based on the metadata in the database. Currently this only works for Snowflake. + **Example:** ```python - plan = vn.get_training_plan() + plan = vn.get_training_plan_experimental(filter_databases=["employees"], filter_schemas=["public"]) vn.train(plan=plan) ``` + """ - Get the training plan for the model. + plan = TrainingPlan([]) - Returns: - TrainingPlan: The training plan for the model. - """ - d = __rpc_call(method="get_training_plan", params=[]) + if run_sql is None: + raise ValidationError("Please connect to a database first.") - if 'result' not in d: - raise ValidationError("Failed to get training plan") + if use_historical_queries: + try: + print("Trying query history") + df_history = run_sql(""" select * from table(information_schema.query_history(result_limit => 5000)) order by start_time""") + + df_history_filtered = df_history.query('ROWS_PRODUCED > 1') + if filter_databases is not None: + mask = df_history_filtered['QUERY_TEXT'].str.lower().apply(lambda x: any(s in x for s in [s.lower() for s in filter_databases])) + df_history_filtered = df_history_filtered[mask] + + if filter_schemas is not None: + mask = df_history_filtered['QUERY_TEXT'].str.lower().apply(lambda x: any(s in x for s in [s.lower() for s in filter_schemas])) + df_history_filtered = df_history_filtered[mask] + + for query in df_history_filtered.sample(10)['QUERY_TEXT'].unique().tolist(): + plan._plan.append(TrainingPlanItem( + item_type=TrainingPlanItem.ITEM_TYPE_SQL, + item_group="", + item_name=generate_question(query), + item_value=query + )) - training_plan = TrainingPlan(**d['result']) + except Exception as e: + print(e) + + databases = __get_databases() + + for database in databases: + if filter_databases is not None and database not in filter_databases: + continue + + try: + df_tables = __get_information_schema_tables(database=database) + + print(f"Trying INFORMATION_SCHEMA.COLUMNS for {database}") + df_columns = run_sql(f"SELECT * FROM {database}.INFORMATION_SCHEMA.COLUMNS") + + for schema in df_tables['TABLE_SCHEMA'].unique().tolist(): + if filter_schemas is not None and schema not in filter_schemas: + continue + + if not include_information_schema and schema == "INFORMATION_SCHEMA": + continue + + df_columns_filtered_to_schema = df_columns.query(f"TABLE_SCHEMA == '{schema}'") + + try: + tables = df_columns_filtered_to_schema['TABLE_NAME'].unique().tolist() + + for table in tables: + df_columns_filtered_to_table = df_columns_filtered_to_schema.query(f"TABLE_NAME == '{table}'") + doc = f"The following columns are in the {table} table in the {database} database:\n\n" + doc += df_columns_filtered_to_table[["TABLE_CATALOG", "TABLE_SCHEMA", "TABLE_NAME", "COLUMN_NAME", "DATA_TYPE", "COMMENT"]].to_markdown() + + plan._plan.append(TrainingPlanItem( + item_type=TrainingPlanItem.ITEM_TYPE_IS, + item_group=f"{database}.{schema}", + item_name=table, + item_value=doc + )) + + except Exception as e: + print(e) + pass + except Exception as e: + print(e) + + # try: + # print("Trying SHOW TABLES") + # df_f = run_sql("SHOW TABLES") + + # for schema in df_f.schema_name.unique(): + # try: + # print(f"Trying GET_DDL for {schema}") + # ddl_df = run_sql(f"SELECT GET_DDL('schema', '{schema}')") + + # plan._plan.append(TrainingPlanItem( + # item_type=TrainingPlanItem.ITEM_TYPE_DDL, + # item_group=schema, + # item_name="All Tables", + # item_value=ddl_df.iloc[0, 0] + # )) + # except: + # pass + # except: + # try: + # print("Trying INFORMATION_SCHEMA.TABLES") + # df = run_sql("SELECT * FROM INFORMATION_SCHEMA.TABLES") + + # breakpoint() + + # try: + # print("Trying SCHEMATA") + # df_schemata = run_sql("SELECT * FROM region-us.INFORMATION_SCHEMA.SCHEMATA") + + # for schema in df_schemata.schema_name.unique(): + # df = run_sql(f"SELECT * FROM {schema}.information_schema.tables") + + # for table in df.table_name.unique(): + # plan._plan.append(TrainingPlanItem( + # item_type=TrainingPlanItem.ITEM_TYPE_IS, + # item_group=schema, + # item_name=table, + # item_value=None + # )) + + # try: + # ddl_df = run_sql(f"SELECT GET_DDL('schema', '{schema}')") + + # plan._plan.append(TrainingPlanItem( + # item_type=TrainingPlanItem.ITEM_TYPE_DDL, + # item_group=schema, + # item_name=None, + # item_value=ddl_df.iloc[0, 0] + # )) + # except: + # pass + # except: + # pass + + return plan - return training_plan def train(question: str = None, sql: str = None, ddl: str = None, documentation: str = None, json_file: str = None, - sql_file: str = None) -> bool: + sql_file: str = None, plan: TrainingPlan = None) -> bool: """ **Example:** ```python @@ -620,6 +756,7 @@ def train(question: str = None, sql: str = None, ddl: str = None, documentation: If you call it with the ddl argument, it's equivalent to [`add_ddl()`][vanna.add_ddl]. If you call it with the documentation argument, it's equivalent to [`add_documentation()`][vanna.add_documentation]. It can also accept a JSON file path or SQL file path to train on a batch of questions and SQL queries or a list of SQL queries respectively. + Additionally, you can pass a [`TrainingPlan`][vanna.TrainingPlan] object. Get a training plan with [`vn.get_training_plan_experimental()`][vanna.get_training_plan_experimental]. Args: question (str): The question to train on. @@ -628,6 +765,7 @@ def train(question: str = None, sql: str = None, ddl: str = None, documentation: json_file (str): The JSON file path. ddl (str): The DDL statement. documentation (str): The documentation to train on. + plan (TrainingPlan): The training plan to train on. """ if question and not sql: @@ -656,7 +794,7 @@ def train(question: str = None, sql: str = None, ddl: str = None, documentation: print("Adding Questions And SQLs using file:", json_file) for question in data: if not add_sql(question=question['question'], sql=question['answer']): - logger.warning(f"Not able to add sql for question: {question['question']} from {json_file}") + print(f"Not able to add sql for question: {question['question']} from {json_file}") return False return True @@ -676,11 +814,24 @@ def train(question: str = None, sql: str = None, ddl: str = None, documentation: if add_sql(question=question, sql=statement): print("SQL added!") return True - logger.warning("Not able to add sql.") + print("Not able to add sql.") return False return False - # Here we're going to attempt auto training + if plan: + for item in plan._plan: + if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL: + if not add_ddl(item.item_value): + print(f"Not able to add ddl for {item.item_group}") + return False + elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS: + if not add_documentation(item.item_value): + print(f"Not able to add documentation for {item.item_group}.{item.item_name}") + return False + elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL: + if not add_sql(question=item.item_name, sql=item.item_value): + print(f"Not able to add sql for {item.item_group}.{item.item_name}") + return False def flag_sql_for_review(question: str, sql: Union[str, None] = None, error_msg: Union[str, None] = None) -> bool: