From 1bec99ad938fc1246f36f7990c955a2ceec51a62 Mon Sep 17 00:00:00 2001 From: Aymane Boumaaza Date: Thu, 18 Apr 2024 14:25:26 +0100 Subject: [PATCH] Swap SELECT and WITH order in extract_sql --- src/vanna/base/base.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index 5360f989..eb1eca78 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -124,6 +124,11 @@ def generate_sql(self, question: str, **kwargs) -> str: return self.extract_sql(llm_response) def extract_sql(self, llm_response: str) -> str: + # If the llm_response contains a CTE (with clause), extract the sql bewteen WITH and ; + sql = re.search(r"WITH.*?;", llm_response, re.DOTALL) + if sql: + self.log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(0)}") + return sql.group(0) # If the llm_response is not markdown formatted, extract sql by finding select and ; in the response sql = re.search(r"SELECT.*?;", llm_response, re.DOTALL) if sql: @@ -131,11 +136,6 @@ def extract_sql(self, llm_response: str) -> str: ) return sql.group(0) - # If the llm_response contains a CTE (with clause), extract the sql bewteen WITH and ; - sql = re.search(r"WITH.*?;", llm_response, re.DOTALL) - if sql: - self.log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(0)}") - return sql.group(0) # If the llm_response contains a markdown code block, with or without the sql tag, extract the sql from it sql = re.search(r"```sql\n(.*)```", llm_response, re.DOTALL) if sql: