From 533e4c5247e6073ed25ffee11a51a7d14b4b839f Mon Sep 17 00:00:00 2001 From: Panos Vagenas <35837085+vagenas@users.noreply.github.com> Date: Tue, 5 Mar 2024 14:20:49 +0100 Subject: [PATCH] feat: add RAG exceptions (#171) Signed-off-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com> --- deepsearch/cps/queries/results.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/deepsearch/cps/queries/results.py b/deepsearch/cps/queries/results.py index 58c121a3..d2ed88f7 100644 --- a/deepsearch/cps/queries/results.py +++ b/deepsearch/cps/queries/results.py @@ -31,16 +31,37 @@ class RAGAnswerItem(BaseModel): grounding: RAGGroundingInfo +class SemanticError(Exception): + pass + + +class GenerationError(SemanticError): + def __init__(self, msg="", *args, **kwargs): + err_msg = "There was an error during generation" + if msg: + err_msg += f": {msg}" + super().__init__(err_msg, *args, **kwargs) + + +class NoSearchResultsError(SemanticError): + def __init__(self, msg="Search returned no results", *args, **kwargs): + super().__init__(msg, *args, **kwargs) + + class RAGResult(BaseModel): answers: List[RAGAnswerItem] search_result_items: List[SearchResultItem] @classmethod - def from_api_output(cls, data: RunQueryResult): + def from_api_output(cls, data: RunQueryResult, raise_on_error=True): answers: List[RAGAnswerItem] = [] try: search_result_items = data.outputs["retrieval"]["items"] + if raise_on_error and len(search_result_items) == 0: + raise NoSearchResultsError() for answer_item in data.outputs["answers"]: + if raise_on_error and (gen_err := answer_item.get("gen_err")): + raise GenerationError(gen_err) answers.append( RAGAnswerItem( answer=answer_item["answer"], @@ -64,9 +85,11 @@ class SearchResult(BaseModel): search_result_items: List[SearchResultItem] @classmethod - def from_api_output(cls, data: RunQueryResult): + def from_api_output(cls, data: RunQueryResult, raise_on_error=True): try: search_result_items = data.outputs["items"] + if raise_on_error and len(search_result_items) == 0: + raise NoSearchResultsError() except KeyError: raise ValueError("Unexpected input format.") return SearchResult(