Skip to content

Commit

Permalink
convert http::reponse into axum::response
Browse files Browse the repository at this point in the history
  • Loading branch information
getong committed Aug 21, 2024
1 parent fa1ccc8 commit c96d444
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 39 deletions.
20 changes: 9 additions & 11 deletions apps/indexer-proxy/proxy/src/ai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ pub async fn connect_remote(
endpoint: String,
tx: Sender<String>,
req: Value,
state: MultipleQueryState,
is_test: bool,
state: Option<MultipleQueryState>,
) -> Result<()> {
let req_s = serde_json::to_string(&req).unwrap_or("".to_owned());
let request: RequestMessage =
Expand All @@ -77,8 +76,8 @@ pub async fn connect_remote(
}

// pay by real count
if !is_test {
pay_by_token(req_num, &tx, state.clone(), true).await?;
if state.is_some() {
pay_by_token(req_num, &tx, state.as_ref().unwrap().clone(), true).await?;
}

// open stream and send query to remote
Expand All @@ -98,8 +97,8 @@ pub async fn connect_remote(

// pay by real count
if count == BATCH {
if !is_test {
pay_by_token(count, &tx, state.clone(), false).await?;
if state.is_some() {
pay_by_token(count, &tx, state.as_ref().unwrap().clone(), false).await?;
}
count = 0;
}
Expand All @@ -110,8 +109,8 @@ pub async fn connect_remote(
}
}

if count != 0 && !is_test {
pay_by_token(count, &tx, state, true).await?;
if count != 0 && state.is_some() {
pay_by_token(count, &tx, state.unwrap(), true).await?;
}

Ok(())
Expand All @@ -120,14 +119,13 @@ pub async fn connect_remote(
pub fn api_stream(
endpoint: String,
req: Value,
state: MultipleQueryState,
is_test: bool,
state: Option<MultipleQueryState>,
) -> impl Stream<Item = String> {
let (tx, rx) = channel::<String>(1024);

tokio::spawn(async move {
let tx1 = tx.clone();
if let Err(err) = connect_remote(endpoint, tx1, req, state, is_test).await {
if let Err(err) = connect_remote(endpoint, tx1, req, state).await {
let state = build_data(err.to_json());
let _ = tx.send(state).await;
}
Expand Down
5 changes: 1 addition & 4 deletions apps/indexer-proxy/proxy/src/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,10 +457,7 @@ impl Project {
self.rpcquery_raw(body, endpoint, payment, network, no_sig, path)
.await?
}
ProjectType::Ai => {
self.rpcquery_raw(body, endpoint, payment, network, no_sig, path)
.await?
}
ProjectType::Ai => (vec![], String::new()),
};

Ok((d, s, waterlevel))
Expand Down
83 changes: 59 additions & 24 deletions apps/indexer-proxy/proxy/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,19 +136,34 @@ async fn ep_wl_query(
deployment: String,
ep_name: String,
body: String,
) -> Result<Response<String>, Error> {
) -> AxumResponse {
if deployment != deployment_id {
return Err(Error::AuthVerify(1004));
return Error::AuthVerify(1004).into_response();
};

let (new_body, path) = match serde_json::from_str::<WhiteListBody>(&body) {
Ok(body) => (body.body, Some((body.path, body.method))),
Err(_) => (body, None),
Err(_) => (body.clone(), None),
};

let project = get_project(&deployment).await?;
let endpoint = project.endpoint(&ep_name, false)?;
let (data, signature, _limit) = project
let project = match get_project(&deployment).await {
Ok(p) => p,
Err(e) => return e.into_response(),
};
let endpoint = match project.endpoint(&ep_name, false) {
Ok(p) => p,
Err(e) => return e.into_response(),
};
if project.is_ai_project() {
let v = match serde_json::from_str::<Value>(&body.clone())
.map_err(|_| Error::Serialize(1142))
{
Ok(p) => p,
Err(e) => return e.into_response(),
};
return payg_stream(endpoint.endpoint.clone(), v, None).await;
}
let (data, signature, _limit) = match project
.check_query(
new_body,
endpoint.endpoint.clone(),
Expand All @@ -158,7 +173,11 @@ async fn ep_wl_query(
false,
path,
)
.await?;
.await
{
Ok(p) => p,
Err(e) => return e.into_response(),
};

let body = serde_json::to_string(&json!({
"result": general_purpose::STANDARD.encode(data),
Expand All @@ -168,22 +187,22 @@ async fn ep_wl_query(

let header = vec![("Content-Type", "application/json")];

Ok(build_response(body, header))
build_response(body, header).into_response()
}

async fn default_wl_query(
AuthWhitelistQuery(deployment_id): AuthWhitelistQuery,
Path(deployment): Path<String>,
body: String,
) -> Result<Response<String>, Error> {
) -> AxumResponse {
ep_wl_query(deployment_id, deployment, "default".to_owned(), body).await
}

async fn wl_query(
AuthWhitelistQuery(deployment_id): AuthWhitelistQuery,
Path((deployment, ep_name)): Path<(String, String)>,
body: String,
) -> Result<Response<String>, Error> {
) -> AxumResponse {
ep_wl_query(deployment_id, deployment, ep_name, body).await
}

Expand Down Expand Up @@ -272,7 +291,7 @@ async fn default_query(
AuthQuery(deployment_id): AuthQuery,
Path(deployment): Path<String>,
body: String,
) -> Result<Response<String>, Error> {
) -> AxumResponse {
ep_query_handler(
headers,
deployment_id,
Expand All @@ -288,7 +307,7 @@ async fn query_handler(
AuthQuery(deployment_id): AuthQuery,
Path((deployment, ep_name)): Path<(String, String)>,
body: String,
) -> Result<Response<String>, Error> {
) -> AxumResponse {
ep_query_handler(headers, deployment_id, deployment, ep_name, body).await
}

Expand All @@ -298,9 +317,9 @@ async fn ep_query_handler(
deployment: String,
ep_name: String,
body: String,
) -> Result<Response<String>, Error> {
) -> AxumResponse {
if COMMAND.auth() && deployment != deployment_id {
return Err(Error::AuthVerify(1004));
return Error::AuthVerify(1004).into_response();
};

let res_fmt = headers
Expand All @@ -311,12 +330,25 @@ async fn ep_query_handler(
.unwrap_or(HeaderValue::from_static("false"));
let no_sig = res_sig.to_str().map(|s| s == "true").unwrap_or(false);

let project = get_project(&deployment).await?;
let endpoint = project.endpoint(&ep_name, true)?;
let project = match get_project(&deployment).await {
Ok(p) => p,
Err(e) => return e.into_response(),
};
let endpoint = match project.endpoint(&ep_name, true) {
Ok(p) => p,
Err(e) => return e.into_response(),
};
if endpoint.is_ws {
return Err(Error::WebSocket(1315));
return Error::WebSocket(1315).into_response();
}
let (data, signature, limit) = project
if project.is_ai_project() {
let v = match serde_json::from_str::<Value>(&body).map_err(|_| Error::Serialize(1142)) {
Ok(p) => p,
Err(e) => return e.into_response(),
};
return payg_stream(endpoint.endpoint.clone(), v, None).await;
}
let (data, signature, limit) = match project
.check_query(
body,
endpoint.endpoint.clone(),
Expand All @@ -326,7 +358,11 @@ async fn ep_query_handler(
no_sig,
None,
)
.await?;
.await
{
Ok(p) => p,
Err(e) => return e.into_response(),
};

let (body, mut headers) = match res_fmt.to_str() {
Ok("inline") => (
Expand Down Expand Up @@ -354,7 +390,7 @@ async fn ep_query_handler(
headers.push(("X-RateLimit-Remaining-Second", (t - u).to_string().leak()));
}

Ok(build_response(body, headers))
build_response(body, headers).into_response()
}

async fn ws_query(
Expand Down Expand Up @@ -453,7 +489,7 @@ async fn ep_payg_handler(
Ok(p) => p,
Err(e) => return e.into_response(),
};
return payg_stream(endpoint.endpoint.clone(), v, state, false).await;
return payg_stream(endpoint.endpoint.clone(), v, Some(state)).await;
}

let (data, signature, state_data, limit) = match block.to_str() {
Expand Down Expand Up @@ -700,10 +736,9 @@ async fn ws_handler(
async fn payg_stream(
endpoint: String,
v: Value,
state: MultipleQueryState,
is_test: bool,
state: Option<MultipleQueryState>,
) -> AxumResponse {
let mut res = StreamBodyAs::text(api_stream(endpoint, v, state, is_test)).into_response();
let mut res = StreamBodyAs::text(api_stream(endpoint, v, state)).into_response();
res.headers_mut()
.insert("Content-Type", "text/event-stream".parse().unwrap());
res.headers_mut()
Expand Down

0 comments on commit c96d444

Please sign in to comment.