-
Notifications
You must be signed in to change notification settings - Fork 39
/
Copy pathmain.py
151 lines (123 loc) · 5.64 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import argparse
import os
import shutil
import uuid
from functools import partial

import uvicorn
from fastapi import FastAPI, UploadFile
from fastapi.encoders import jsonable_encoder
from fastapi.responses import PlainTextResponse

from config import TEMP_DIR
# Make sure the staging directory for uploaded files exists.
os.makedirs(TEMP_DIR, exist_ok=True)
# Set before prometheus_client is imported, which happens further down only
# when --moniter is given. Judging by its name, this variable turns off the
# `*_created` series; confirm against the prometheus_client version in use.
os.environ['PROMETHEUS_DISABLE_CREATED_SERIES'] = 'true'

# Specify mode
parser = argparse.ArgumentParser(description='Start service with different modes.')
parser.add_argument('--langchain', action='store_true')
parser.add_argument('--towhee', action='store_true')
parser.add_argument('--moniter', action='store_true')
parser.add_argument('--agent', action='store_true',
                    help='The default is False, which only works when `--langchain` is enabled.'
                         ' It means using the agent in langchain to dynamically select tools.')
# type=int so a command-line value is converted just like the int default.
# Without it, argparse would pass the raw string on to enable_moniter().
parser.add_argument('--max_observation', type=int, default=1000)
# A fresh UUID per process unless an explicit name is supplied.
parser.add_argument('--name', default=str(uuid.uuid4()))
args = parser.parse_args()

app = FastAPI()
# NOTE(review): `origins` is not referenced anywhere else in this file (no
# CORS middleware is installed). Confirm whether it is still needed.
origins = ['*']

# Apply args
USE_LANGCHAIN = args.langchain
USE_TOWHEE = args.towhee
MAX_OBSERVATION = args.max_observation
ENABLE_MONITER = args.moniter
ENABLE_AGENT = args.agent
NAME = args.name

# Exactly one backend must be selected, i.e. an XOR of the two flags.
# parser.error() prints usage and exits with status 2. It replaces an
# `assert`, which `python -O` would silently strip.
if USE_LANGCHAIN == USE_TOWHEE:
    parser.error('The service should start with either "--langchain" or "--towhee".')
# Bind the backend operation functions used by the endpoints below.
# Which module provides them depends on the selected mode flag.
if USE_LANGCHAIN:
    from src.langchain.operations import (  # pylint: disable=C0413
        chat, insert, drop, check, get_history, clear_history, count,
    )
    # The LangChain chat takes an extra `enable_agent` keyword. Pre-bind it
    # so every caller can use the same chat(session_id=..., ...) call shape.
    chat = partial(chat, enable_agent=ENABLE_AGENT)
if USE_TOWHEE:
    from src.towhee.operations import (  # pylint: disable=C0413
        chat, insert, drop, check, get_history, clear_history, count,
    )
# Optional monitoring: installed only when started with --moniter.
if ENABLE_MONITER:
    from moniter import enable_moniter  # pylint: disable=C0413
    from prometheus_client import generate_latest, REGISTRY  # pylint: disable=C0413

    enable_moniter(app, MAX_OBSERVATION, NAME)

    @app.get('/metrics')
    async def metrics():
        """Return the default Prometheus registry as produced by ``generate_latest``.

        The payload is sent as ``text/plain`` via ``PlainTextResponse``.
        """
        exposition = generate_latest(REGISTRY)
        return PlainTextResponse(content=exposition, media_type='text/plain')
@app.get('/')
def check_api():
    """Liveness probe: always reports ``{'status': True, 'msg': 'ok'}``.

    NOTE(review): FastAPI does not unpack a returned ``(body, status)`` tuple
    the way Flask does. It appears the tuple is serialized as a JSON array,
    and the HTTP status is whatever the route default is. Confirm before
    relying on the trailing ``200``.
    """
    return jsonable_encoder({'status': True, 'msg': 'ok'}), 200
@app.get('/answer')
def do_answer_api(session_id: str, project: str, question: str):
    """Answer ``question`` within ``project`` for chat session ``session_id``.

    Delegates to the backend ``chat`` operation, which returns the
    (possibly rewritten) question and the final answer.

    Returns:
        ``(body, 200)`` on success. ``body['msg']`` is the answer and
        ``body['debug']`` echoes the original and modified question.
        ``(body, 400)`` if ``chat`` raises or returns a non-str answer; that
        body also carries ``'code': 400``.

    NOTE(review): FastAPI does not unpack a ``(body, status)`` tuple like
    Flask. The client presumably receives it as a JSON array with the
    route's default HTTP status. Confirm before relying on the second element.
    """
    try:
        new_question, final_answer = chat(session_id=session_id, project=project, question=question)
        # Explicit check rather than `assert`, which `python -O` strips and
        # would let a non-str answer through to the client.
        if not isinstance(final_answer, str):
            raise TypeError(f'Expected the answer to be str, got {type(final_answer).__name__}')
        return jsonable_encoder({
            'status': True,
            'msg': final_answer,
            'debug': {
                'original question': question,
                'modified question': new_question,
                'answer': final_answer,
            }
        }), 200
    except Exception as e:  # pylint: disable=W0703
        return jsonable_encoder({'status': False, 'msg': f'Failed to answer question:\n{e}', 'code': 400}), 400
@app.post('/project/add')
def do_project_add_api(project: str, url: str = None, file: UploadFile = None):
    """Load a document into ``project`` from a URL and/or an uploaded file.

    An uploaded file is first written under ``TEMP_DIR`` and then passed to
    the backend ``insert`` by path. If both ``url`` and ``file`` are given,
    both are inserted. The reported counts are those of the file insert.

    Returns:
        ``(body, 200)`` with ``'chunk count'`` and ``'token count'`` on
        success, or ``(body, 400)`` on any failure. A missing source, an
        unusable filename, or a backend error all go through the 400 path.

    NOTE(review): the staged file is not removed after ``insert``. Presumably
    ``insert`` reads it synchronously; confirm before adding cleanup. Also,
    FastAPI does not unpack a ``(body, status)`` tuple like Flask does.
    """
    try:
        # Raised inside the try (instead of an `assert` outside it) so a
        # missing source gets the same JSON error response as other failures,
        # and the check still runs under `python -O`.
        if not (url or file):
            raise ValueError('You need to upload file or enter url of document to add data.')
        if url:
            chunk_num, token_count = insert(data_src=url, project=project, source_type='url')
        if file:
            # `file.filename` is client-controlled. Keep only its last path
            # component so names like '../x' or '/etc/x' cannot write outside
            # TEMP_DIR.
            filename = os.path.basename(file.filename or '')
            if filename in ('', '.', '..'):
                raise ValueError('Uploaded file has no usable filename.')
            temp_file = os.path.join(TEMP_DIR, filename)
            with open(temp_file, 'wb') as f:
                # Stream in chunks rather than reading the whole upload into memory.
                shutil.copyfileobj(file.file, f)
            chunk_num, token_count = insert(data_src=temp_file, project=project, source_type='file')
        return jsonable_encoder({'status': True, 'msg': {
            'chunk count': chunk_num,
            'token count': token_count
        }}), 200
    except Exception as e:  # pylint: disable=W0703
        return jsonable_encoder({'status': False, 'msg': f'Failed to load data:\n{e}'}), 400
@app.post('/project/drop')
def do_project_drop_api(project: str):
    """Drop the stored data for ``project`` (the vector db, per the original comment).

    Returns ``(body, 200)`` on success, or ``(body, 400)`` when the backend
    ``drop`` (or encoding the reply) raises.

    NOTE(review): FastAPI does not unpack a ``(body, status)`` tuple like Flask.
    """
    try:
        drop(project=project)
        done = jsonable_encoder({'status': True, 'msg': f'Dropped project: {project}'})
    except Exception as e:  # pylint: disable=W0703
        return jsonable_encoder({'status': False, 'msg': f'Failed to drop project:\n{e}'}), 400
    return done, 200
@app.get('/project/check')
def do_project_check_api(project: str):
    """Report the backend ``check`` result for ``project`` under ``'msg'``.

    Returns ``(body, 200)`` on success, or ``(body, 400)`` if ``check`` or
    encoding its result raises.

    NOTE(review): FastAPI does not unpack a ``(body, status)`` tuple like Flask.
    """
    try:
        outcome = jsonable_encoder({'status': True, 'msg': check(project)})
    except Exception as e:  # pylint: disable=W0703
        return jsonable_encoder({'status': False, 'msg': f'Failed to check project:\n{e}'}), 400
    return outcome, 200
@app.get('/project/count')
def do_project_count_api(project: str):
    """Report the backend ``count`` result for ``project`` under ``'msg'``.

    Returns ``(body, 200)`` on success, or ``(body, 400)`` if ``count`` or
    encoding its result raises.

    NOTE(review): FastAPI does not unpack a ``(body, status)`` tuple like Flask.
    """
    try:
        tally = count(project)
        reply = jsonable_encoder({'status': True, 'msg': tally})
    except Exception as e:  # pylint: disable=W0703
        return jsonable_encoder({'status': False, 'msg': f'Failed to count entities:\n{e}'}), 400
    return reply, 200
@app.get('/history/get')
def do_history_get_api(project: str, session_id: str = None):
    """Return the chat history the backend ``get_history`` yields for ``project``.

    ``session_id`` is optional and is forwarded as ``None`` when omitted.
    How the backend interprets ``None`` is not visible here.

    Returns ``(body, 200)`` on success, or ``(body, 400)`` if the lookup or
    encoding raises.

    NOTE(review): FastAPI does not unpack a ``(body, status)`` tuple like Flask.
    """
    try:
        records = get_history(project=project, session_id=session_id)
        reply = jsonable_encoder({'status': True, 'msg': records})
    except Exception as e:  # pylint: disable=W0703
        return jsonable_encoder({'status': False, 'msg': f'Failed to get history:\n{e}'}), 400
    return reply, 200
@app.get('/history/clear')
def do_history_clear_api(project: str, session_id: str = None):
    """Clear chat history for ``project`` via the backend ``clear_history``.

    ``session_id`` is optional and is forwarded as ``None`` when omitted.
    How the backend interprets ``None`` is not visible here.

    Returns ``(body, 200)`` on success, or ``(body, 400)`` if the backend raises.

    NOTE(review): FastAPI does not unpack a ``(body, status)`` tuple like Flask.
    """
    try:
        clear_history(project=project, session_id=session_id)
        reply = jsonable_encoder(
            {'status': True, 'msg': f'Successfully clear history for project {project} ({session_id}).'}
        )
    except Exception as e:  # pylint: disable=W0703
        return jsonable_encoder({'status': False, 'msg': f'Failed to clear history:\n{e}'}), 400
    return reply, 200
if __name__ == '__main__':
    # When run as a script, serve the app with uvicorn on every interface, port 8900.
    uvicorn.run(app, host='0.0.0.0', port=8900)