diff --git a/src/vanna/__init__.py b/src/vanna/__init__.py index b0bd605a..d240398c 100644 --- a/src/vanna/__init__.py +++ b/src/vanna/__init__.py @@ -29,7 +29,7 @@ | `vn.add_` | Adds something to the dataset | [`vn.add_sql(...)`][vanna.add_sql]
[`vn.add_ddl(...)`][vanna.add_ddl] | | `vn.generate_` | Generates something using AI based on the information in the dataset | [`vn.generate_sql(...)`][vanna.generate_sql]
def remove_training_data(id: str) -> bool:
    """
    Remove training data from the dataset

    **Example:**
    ```python
    vn.remove_training_data(id="1-ddl")
    ```

    Args:
        id (str): The ID of the training data to remove.

    Returns:
        bool: The success flag of the returned status. Failures are reported
            by raising, so the value is truthy whenever this function returns.

    Raises:
        Exception: If the RPC response has no 'result' entry, or if the
            returned status reports that the removal did not succeed.
    """
    params = [StringData(data=id)]

    d = __rpc_call(method="remove_training_data", params=params)

    # A response without 'result' cannot be turned into a Status, so fail
    # loudly here instead of returning a value the caller might ignore.
    if 'result' not in d:
        raise Exception("Error removing training data")

    status = Status(**d['result'])

    if not status.success:
        raise Exception(f"Error removing training data: {status.message}")

    return status.success
def test_get_training_data():
    # Only the shape is checked: the frame is expected to be 3 rows by 4 columns.
    rv = vn.get_training_data()
    assert rv.shape == (3, 4)


def test_remove_training_data():
    # Remove every entry one at a time and check that the stored row count
    # drops by exactly one after each removal.
    training_data = vn.get_training_data()
    total = len(training_data)

    # Count removals with enumerate instead of relying on the DataFrame's
    # index labels, and derive the expected size from the starting length
    # rather than hard-coding it.
    for removed, (_, row) in enumerate(training_data.iterrows(), start=1):
        rv = vn.remove_training_data(row['id'])
        assert rv == True

        assert vn.get_training_data().shape[0] == total - removed