From 25da193bf99357ae3680e0f36fefc7b4ec8cd38e Mon Sep 17 00:00:00 2001
From: Zain Hoda <7146154+zainhoda@users.noreply.github.com>
Date: Mon, 24 Jul 2023 22:35:47 -0400
Subject: [PATCH] remove training data
---
src/vanna/__init__.py | 29 ++++++++++++++++++++++++++++-
tests/test_vanna.py | 20 ++++++++++++++++----
2 files changed, 44 insertions(+), 5 deletions(-)
diff --git a/src/vanna/__init__.py b/src/vanna/__init__.py
index b0bd605a..d240398c 100644
--- a/src/vanna/__init__.py
+++ b/src/vanna/__init__.py
@@ -29,7 +29,7 @@
| `vn.add_` | Adds something to the dataset | [`vn.add_sql(...)`][vanna.add_sql]
[`vn.add_ddl(...)`][vanna.add_ddl] |
| `vn.generate_` | Generates something using AI based on the information in the dataset | [`vn.generate_sql(...)`][vanna.generate_sql]
[`vn.generate_explanation()`][vanna.generate_explanation] |
| `vn.run_` | Runs code (SQL or Plotly) | [`vn.run_sql`][vanna.run_sql] |
-| `vn.remove_` | Removes something from the dataset | [`vn.remove_sql`][vanna.remove_sql] |
+| `vn.remove_` | Removes something from the dataset | [`vn.remove_training_data`][vanna.remove_training_data] |
| `vn.update_` | Updates something in the dataset | [`vn.update_dataset_visibility(...)`][vanna.update_dataset_visibility] |
| `vn.connect_` | Connects to a database | [`vn.connect_to_snowflake(...)`][vanna.connect_to_snowflake] |
@@ -547,6 +547,33 @@ def remove_sql(question: str) -> bool:
return status.success
+def remove_training_data(id: str) -> bool:
+    """
+    Remove training data from the dataset
+
+    **Example:**
+    ```python
+    vn.remove_training_data(id="1-ddl")
+    ```
+
+    Args:
+        id (str): The ID of the training data to remove.
+
+    Returns:
+        bool: The `success` flag of the returned status (failures raise instead).
+
+    Raises:
+        Exception: If the response has no 'result' key or the status is unsuccessful.
+    """
+    d = __rpc_call(method="remove_training_data", params=[StringData(data=id)])
+    if 'result' not in d:
+        raise Exception("Error removing training data")
+
+    status = Status(**d['result'])
+    if not status.success:
+        raise Exception(f"Error removing training data: {status.message}")
+    return status.success
+
def generate_sql(question: str) -> str:
"""
**Example:**
diff --git a/tests/test_vanna.py b/tests/test_vanna.py
index ed25a534..d2056dfd 100644
--- a/tests/test_vanna.py
+++ b/tests/test_vanna.py
@@ -164,10 +164,6 @@ def test_flag_sql():
rv = vn.flag_sql_for_review(question="What's the data about student Jane Doe?")
assert rv == True
-def test_get_training_data():
- rv = vn.get_training_data()
- assert rv.to_csv() == ",id,training_data_type,content\n0,3-sql,sql,SELECT * FROM students WHERE name = 'Jane Doe'\n"
-
def test_get_all_questions():
rv = vn.get_all_questions()
assert rv.shape == (3, 5)
@@ -188,3 +184,19 @@ def test_add_ddl():
rv = vn.add_ddl(ddl="This is the ddl")
assert rv == True
+def test_add_sql2():
+    """Adding another question/SQL pair via add_sql should return a truthy success value."""
+    assert vn.add_sql(question="How many students are there?", sql="SELECT * FROM students") == True
+
+def test_get_training_data():
+    """Check that get_training_data() returns a result whose shape is (3, 4)."""
+    assert vn.get_training_data().shape == (3, 4)
+
+def test_remove_training_data():
+    """Remove every training row by id, checking the remaining row count drops by one each time."""
+    training_data = vn.get_training_data()
+    total = len(training_data)
+
+    for removed, training_id in enumerate(training_data['id'], start=1):
+        assert vn.remove_training_data(training_id) == True
+        assert vn.get_training_data().shape[0] == total - removed
\ No newline at end of file