Skip to content

Commit

Permalink
Merge pull request #44 from vanna-ai/delete-training-data
Browse files Browse the repository at this point in the history
remove training data
  • Loading branch information
zainhoda authored Jul 25, 2023
2 parents 36b787e + 25da193 commit cc25d4e
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 5 deletions.
29 changes: 28 additions & 1 deletion src/vanna/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
| `vn.add_` | Adds something to the dataset | [`vn.add_sql(...)`][vanna.add_sql] <br> [`vn.add_ddl(...)`][vanna.add_ddl] |
| `vn.generate_` | Generates something using AI based on the information in the dataset | [`vn.generate_sql(...)`][vanna.generate_sql] <br> [`vn.generate_explanation()`][vanna.generate_explanation] |
| `vn.run_` | Runs code (SQL or Plotly) | [`vn.run_sql`][vanna.run_sql] |
| `vn.remove_` | Removes something from the dataset | [`vn.remove_sql`][vanna.remove_sql] |
| `vn.remove_` | Removes something from the dataset | [`vn.remove_training_data`][vanna.remove_training_data] |
| `vn.update_` | Updates something in the dataset | [`vn.update_dataset_visibility(...)`][vanna.update_dataset_visibility] |
| `vn.connect_` | Connects to a database | [`vn.connect_to_snowflake(...)`][vanna.connect_to_snowflake] |
Expand Down Expand Up @@ -547,6 +547,33 @@ def remove_sql(question: str) -> bool:

return status.success

def remove_training_data(id: str) -> bool:
"""
Remove training data from the dataset
**Example:**
```python
vn.remove_training_data(id="1-ddl")
```
Args:
id (str): The ID of the training data to remove.
"""
params = [StringData(data=id)]

d = __rpc_call(method="remove_training_data", params=params)

if 'result' not in d:
raise Exception(f"Error removing training data")
return False

status = Status(**d['result'])

if not status.success:
raise Exception(f"Error removing training data: {status.message}")

return status.success

def generate_sql(question: str) -> str:
"""
**Example:**
Expand Down
20 changes: 16 additions & 4 deletions tests/test_vanna.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,6 @@ def test_flag_sql():
rv = vn.flag_sql_for_review(question="What's the data about student Jane Doe?")
assert rv == True

def test_get_training_data():
rv = vn.get_training_data()
assert rv.to_csv() == ",id,training_data_type,content\n0,3-sql,sql,SELECT * FROM students WHERE name = 'Jane Doe'\n"

def test_get_all_questions():
rv = vn.get_all_questions()
assert rv.shape == (3, 5)
Expand All @@ -188,3 +184,19 @@ def test_add_ddl():
rv = vn.add_ddl(ddl="This is the ddl")
assert rv == True

def test_add_sql2():
rv = vn.add_sql(question="How many students are there?", sql="SELECT * FROM students")
assert rv == True

def test_get_training_data():
rv = vn.get_training_data()
assert rv.shape == (3, 4)

def test_remove_training_data():
training_data = vn.get_training_data()

for index, row in training_data.iterrows():
rv = vn.remove_training_data(row['id'])
assert rv == True

assert vn.get_training_data().shape[0] == 2-index

0 comments on commit cc25d4e

Please sign in to comment.