-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathchain.py
58 lines (48 loc) · 1.63 KB
/
chain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.pydantic_v1 import BaseModel
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.vectorstores import MongoDBAtlasVectorSearch
from pymongo import MongoClient
# Database settings. MONGO_URI must come from the environment; the database,
# collection and Atlas Vector Search index names are fixed for this chain.
MONGO_URI = os.environ.get("MONGO_URI")
if not MONGO_URI:
    # Fail fast at import time. An unset *or empty* URI would otherwise be
    # passed straight to MongoClient and only fail later inside pymongo.
    raise RuntimeError("Missing `MONGO_URI` environment variable.")
DB_NAME = "langchain"
COLLECTION_NAME = "vectorSearch"
ATLAS_VECTOR_SEARCH_INDEX_NAME = "default"

# One MongoClient for the whole module: pymongo clients own a connection pool
# and are meant to be created once per process and reused.
client = MongoClient(MONGO_URI)
db = client[DB_NAME]
MONGODB_COLLECTION = db[COLLECTION_NAME]
# Read from MongoDB Atlas Vector Search.
# The store is built on the collection handle created above rather than via
# `from_connection_string`, which would open a second MongoClient (and a
# second connection pool) to the same cluster and namespace.
vectorstore = MongoDBAtlasVectorSearch(
    MONGODB_COLLECTION,
    OpenAIEmbeddings(),
    index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME,
)
# Similarity search requests k=100 neighbours. The post-filter aggregation
# stage `{"$limit": 1}` then keeps only the first result, so a single document
# reaches the prompt.
# NOTE(review): confirm a single context document is intended; the large k may
# be deliberate to widen the approximate-search candidate pool.
retriever = vectorstore.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 100, "post_filter_pipeline": [{"$limit": 1}]},
)
# RAG prompt
template = """Answer the question based only on the following context:
{context}
Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)


def _format_docs(docs):
    """Join the retrieved documents' text into a single context string.

    Only ``page_content`` is used. Document metadata (e.g. the MongoDB ``_id``
    and whatever other stored fields the vector store returns) is therefore
    kept out of the prompt, instead of being pasted in as a ``Document`` repr.
    """
    return "\n\n".join(doc.page_content for doc in docs)


# RAG
# NOTE(review): "gpt-3.5-turbo-16k-0613" is a dated model snapshot; confirm it
# is still served by the OpenAI API before deploying.
model = ChatOpenAI(model_name="gpt-3.5-turbo-16k-0613", temperature=0)
chain = (
    # The incoming question goes both to the retriever (its documents are then
    # flattened to plain text) and, unchanged, to the prompt's {question} slot.
    RunnableParallel(
        {"context": retriever | _format_docs, "question": RunnablePassthrough()}
    )
    | prompt
    | model
    | StrOutputParser()
)
# Input schema for the chain: callers send a bare question string.
class Question(BaseModel):
    """The user's question, modelled as a plain string.

    Uses a pydantic v1 custom root type (``__root__``), so the schema is a
    bare string rather than an object with named fields.
    """

    __root__: str


# Rebind the module-level chain to a copy that declares `Question` as its
# input type.
chain = chain.with_types(input_type=Question)