Skip to content

Commit

Permalink
feat: expose sem. params, default to hybrid search (#159)
Browse files Browse the repository at this point in the history
Signed-off-by: Panos Vagenas <[email protected]>
  • Loading branch information
vagenas authored Jan 22, 2024
1 parent 79fe58a commit 5fa0963
Showing 1 changed file with 110 additions and 16 deletions.
126 changes: 110 additions & 16 deletions deepsearch/cps/queries/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import Any, Dict, List, Optional, Union

from pydantic.v1 import Field, validate_arguments
from typing_extensions import Annotated

from deepsearch.cps.client.components.elastic import ElasticSearchQuery
from deepsearch.cps.client.components.projects import Project, SemanticBackendResource
from deepsearch.cps.client.queries import Query, TaskCoordinates
Expand Down Expand Up @@ -77,17 +80,38 @@ def DataQuery(
return query


ConstrainedWeight = Annotated[
float, Field(strict=True, ge=0.0, le=1.0, multiple_of=0.1)
]


def CorpusRAGQuery(
question: str,
*,
project: Union[str, Project],
index_key: str,
retr_k: int = 10,
rerank: bool = False,
text_weight: ConstrainedWeight = 0.1,
) -> Query:

return _get_rag_query(
"""Create a RAG query against a collection
Args:
question (str): the natural-language query
project (Union[str, Project]): project to use
index_key (str): index key of target private collection (must already be semantically indexed)
retr_k (int, optional): num of items to retrieve; defaults to 10
rerank (bool, optional): whether to rerank retrieval results; defaults to False
text_weight (ConstrainedWeight, optional): lexical weight for hybrid search; allowed values: {0.0, 0.1, 0.2, ..., 1.0}; defaults to 0.1
"""

return _create_rag_query(
question=question,
project=project,
index_key=index_key,
retr_k=retr_k,
rerank=rerank,
text_weight=text_weight,
)


Expand All @@ -96,29 +120,56 @@ def DocumentRAGQuery(
*,
document_hash: str,
project: Union[str, Project],
index_key: Optional[str] = None, # set in case of private collection
index_key: Optional[str] = None,
retr_k: int = 10,
rerank: bool = False,
text_weight: ConstrainedWeight = 0.1,
) -> Query:

return _get_rag_query(
"""Create a RAG query against a specific document
Args:
question (str): the natural-language query
document_hash (str): hash of target document
project (Union[str, Project]): project to use
index_key (str, optional): index key of target private collection (must already be semantically indexed) in case doc within one; defaults to None (doc must already be semantically indexed)
retr_k (int, optional): num of items to retrieve; defaults to 10
rerank (bool, optional): whether to rerank retrieval results; defaults to False
text_weight (ConstrainedWeight, optional): lexical weight for hybrid search; allowed values: {0.0, 0.1, 0.2, ..., 1.0}; defaults to 0.1
"""

return _create_rag_query(
question=question,
document_hash=document_hash,
project=project,
index_key=index_key,
retr_k=retr_k,
rerank=rerank,
text_weight=text_weight,
)


def _get_rag_query(
@validate_arguments
def _create_rag_query(
question: str,
*,
document_hash: Optional[str] = None,
project: Union[str, Project],
index_key: Optional[str] = None,
index_key: Optional[str],
retr_k: int,
rerank: bool,
text_weight: ConstrainedWeight,
) -> Query:
proj_key = project.key if isinstance(project, Project) else project
idx_key = index_key or "__project__"

query = Query()
q_params = {"question": question}

q_params = {
"question": question,
"retr_k": retr_k,
"use_reranker": rerank,
"hybrid_search_text_weight": text_weight,
}
if document_hash:
q_params["doc_id"] = document_hash
task = query.add(
Expand All @@ -138,12 +189,28 @@ def CorpusSemanticQuery(
*,
project: Union[str, Project],
index_key: str,
retr_k: int = 10,
rerank: bool = False,
text_weight: ConstrainedWeight = 0.1,
) -> Query:

return _get_semantic_query(
"""Create a semantic retrieval query against a collection
Args:
question (str): the natural-language query
project (Union[str, Project]): project to use
index_key (str): index key of target private collection (must already be semantically indexed)
retr_k (int, optional): num of items to retrieve; defaults to 10
rerank (bool, optional): whether to rerank retrieval results; defaults to False
text_weight (ConstrainedWeight, optional): lexical weight for hybrid search; allowed values: {0.0, 0.1, 0.2, ..., 1.0}; defaults to 0.1
"""

return _create_semantic_query(
question=question,
project=project,
index_key=index_key,
retr_k=retr_k,
rerank=rerank,
text_weight=text_weight,
)


Expand All @@ -152,29 +219,56 @@ def DocumentSemanticQuery(
*,
document_hash: str,
project: Union[str, Project],
index_key: Optional[str] = None, # set in case of private collection
index_key: Optional[str] = None,
retr_k: int = 10,
rerank: bool = False,
text_weight: ConstrainedWeight = 0.1,
) -> Query:

return _get_semantic_query(
"""Create a semantic retrieval query against a specific document
Args:
question (str): the natural-language query
document_hash (str): hash of target document
project (Union[str, Project]): project to use
index_key (str, optional): index key of target private collection (must already be semantically indexed) in case doc within one; defaults to None (doc must already be semantically indexed)
retr_k (int, optional): num of items to retrieve; defaults to 10
rerank (bool, optional): whether to rerank retrieval results; defaults to False
text_weight (ConstrainedWeight, optional): lexical weight for hybrid search; allowed values: {0.0, 0.1, 0.2, ..., 1.0}; defaults to 0.1
"""

return _create_semantic_query(
question=question,
document_hash=document_hash,
project=project,
index_key=index_key,
retr_k=retr_k,
rerank=rerank,
text_weight=text_weight,
)


def _get_semantic_query(
@validate_arguments
def _create_semantic_query(
question: str,
*,
document_hash: Optional[str] = None,
project: Union[str, Project],
index_key: Optional[str] = None,
index_key: Optional[str],
retr_k: int,
rerank: bool,
text_weight: ConstrainedWeight,
) -> Query:
proj_key = project.key if isinstance(project, Project) else project
idx_key = index_key or "__project__"

query = Query()
q_params = {"question": question}

q_params = {
"question": question,
"retr_k": retr_k,
"use_reranker": rerank,
"hybrid_search_text_weight": text_weight,
}
if document_hash:
q_params["doc_id"] = document_hash
task = query.add(
Expand Down

0 comments on commit 5fa0963

Please sign in to comment.