From c96d444685d87e2ee33b4d231ec02df7790efd7d Mon Sep 17 00:00:00 2001 From: gerald <3949379+getong@users.noreply.github.com> Date: Mon, 19 Aug 2024 17:02:21 +0800 Subject: [PATCH] convert http::reponse into axum::response --- apps/indexer-proxy/proxy/src/ai.rs | 20 +++--- apps/indexer-proxy/proxy/src/project.rs | 5 +- apps/indexer-proxy/proxy/src/server.rs | 83 ++++++++++++++++++------- 3 files changed, 69 insertions(+), 39 deletions(-) diff --git a/apps/indexer-proxy/proxy/src/ai.rs b/apps/indexer-proxy/proxy/src/ai.rs index 6044bad98..535d46043 100644 --- a/apps/indexer-proxy/proxy/src/ai.rs +++ b/apps/indexer-proxy/proxy/src/ai.rs @@ -58,8 +58,7 @@ pub async fn connect_remote( endpoint: String, tx: Sender, req: Value, - state: MultipleQueryState, - is_test: bool, + state: Option, ) -> Result<()> { let req_s = serde_json::to_string(&req).unwrap_or("".to_owned()); let request: RequestMessage = @@ -77,8 +76,8 @@ pub async fn connect_remote( } // pay by real count - if !is_test { - pay_by_token(req_num, &tx, state.clone(), true).await?; + if state.is_some() { + pay_by_token(req_num, &tx, state.as_ref().unwrap().clone(), true).await?; } // open stream and send query to remote @@ -98,8 +97,8 @@ pub async fn connect_remote( // pay by real count if count == BATCH { - if !is_test { - pay_by_token(count, &tx, state.clone(), false).await?; + if state.is_some() { + pay_by_token(count, &tx, state.as_ref().unwrap().clone(), false).await?; } count = 0; } @@ -110,8 +109,8 @@ pub async fn connect_remote( } } - if count != 0 && !is_test { - pay_by_token(count, &tx, state, true).await?; + if count != 0 && state.is_some() { + pay_by_token(count, &tx, state.unwrap(), true).await?; } Ok(()) @@ -120,14 +119,13 @@ pub async fn connect_remote( pub fn api_stream( endpoint: String, req: Value, - state: MultipleQueryState, - is_test: bool, + state: Option, ) -> impl Stream { let (tx, rx) = channel::(1024); tokio::spawn(async move { let tx1 = tx.clone(); - if let Err(err) = connect_remote(endpoint, tx1, req, state, is_test).await { + if let Err(err) = connect_remote(endpoint, tx1, req, state).await { let state = build_data(err.to_json()); let _ = tx.send(state).await; } diff --git a/apps/indexer-proxy/proxy/src/project.rs b/apps/indexer-proxy/proxy/src/project.rs index a13d148f1..8681af8a0 100644 --- a/apps/indexer-proxy/proxy/src/project.rs +++ b/apps/indexer-proxy/proxy/src/project.rs @@ -457,10 +457,7 @@ impl Project { self.rpcquery_raw(body, endpoint, payment, network, no_sig, path) .await? } - ProjectType::Ai => { - self.rpcquery_raw(body, endpoint, payment, network, no_sig, path) - .await? - } + ProjectType::Ai => (vec![], String::new()), }; Ok((d, s, waterlevel)) diff --git a/apps/indexer-proxy/proxy/src/server.rs b/apps/indexer-proxy/proxy/src/server.rs index 67b8ebf46..1670cec04 100644 --- a/apps/indexer-proxy/proxy/src/server.rs +++ b/apps/indexer-proxy/proxy/src/server.rs @@ -136,19 +136,34 @@ async fn ep_wl_query( deployment: String, ep_name: String, body: String, -) -> Result, Error> { +) -> AxumResponse { if deployment != deployment_id { - return Err(Error::AuthVerify(1004)); + return Error::AuthVerify(1004).into_response(); }; let (new_body, path) = match serde_json::from_str::(&body) { Ok(body) => (body.body, Some((body.path, body.method))), - Err(_) => (body, None), + Err(_) => (body.clone(), None), }; - let project = get_project(&deployment).await?; - let endpoint = project.endpoint(&ep_name, false)?; - let (data, signature, _limit) = project + let project = match get_project(&deployment).await { + Ok(p) => p, + Err(e) => return e.into_response(), + }; + let endpoint = match project.endpoint(&ep_name, false) { + Ok(p) => p, + Err(e) => return e.into_response(), + }; + if project.is_ai_project() { + let v = match serde_json::from_str::(&body.clone()) + .map_err(|_| Error::Serialize(1142)) + { + Ok(p) => p, + Err(e) => return e.into_response(), + }; + return payg_stream(endpoint.endpoint.clone(), v, None).await; + } + let (data, signature, _limit) = match project .check_query( new_body, endpoint.endpoint.clone(), @@ -158,7 +173,11 @@ async fn ep_wl_query( false, path, ) - .await?; + .await + { + Ok(p) => p, + Err(e) => return e.into_response(), + }; let body = serde_json::to_string(&json!({ "result": general_purpose::STANDARD.encode(data), @@ -168,14 +187,14 @@ async fn ep_wl_query( let header = vec![("Content-Type", "application/json")]; - Ok(build_response(body, header)) + build_response(body, header).into_response() } async fn default_wl_query( AuthWhitelistQuery(deployment_id): AuthWhitelistQuery, Path(deployment): Path, body: String, -) -> Result, Error> { +) -> AxumResponse { ep_wl_query(deployment_id, deployment, "default".to_owned(), body).await } @@ -183,7 +202,7 @@ async fn wl_query( AuthWhitelistQuery(deployment_id): AuthWhitelistQuery, Path((deployment, ep_name)): Path<(String, String)>, body: String, -) -> Result, Error> { +) -> AxumResponse { ep_wl_query(deployment_id, deployment, ep_name, body).await } @@ -272,7 +291,7 @@ async fn default_query( AuthQuery(deployment_id): AuthQuery, Path(deployment): Path, body: String, -) -> Result, Error> { +) -> AxumResponse { ep_query_handler( headers, deployment_id, @@ -288,7 +307,7 @@ async fn query_handler( AuthQuery(deployment_id): AuthQuery, Path((deployment, ep_name)): Path<(String, String)>, body: String, -) -> Result, Error> { +) -> AxumResponse { ep_query_handler(headers, deployment_id, deployment, ep_name, body).await } @@ -298,9 +317,9 @@ async fn ep_query_handler( deployment: String, ep_name: String, body: String, -) -> Result, Error> { +) -> AxumResponse { if COMMAND.auth() && deployment != deployment_id { - return Err(Error::AuthVerify(1004)); + return Error::AuthVerify(1004).into_response(); }; let res_fmt = headers @@ -311,12 +330,25 @@ async fn ep_query_handler( .unwrap_or(HeaderValue::from_static("false")); let no_sig = res_sig.to_str().map(|s| s == "true").unwrap_or(false); - let project = get_project(&deployment).await?; - let endpoint = project.endpoint(&ep_name, true)?; + let project = match get_project(&deployment).await { + Ok(p) => p, + Err(e) => return e.into_response(), + }; + let endpoint = match project.endpoint(&ep_name, true) { + Ok(p) => p, + Err(e) => return e.into_response(), + }; if endpoint.is_ws { - return Err(Error::WebSocket(1315)); + return Error::WebSocket(1315).into_response(); } - let (data, signature, limit) = project + if project.is_ai_project() { + let v = match serde_json::from_str::(&body).map_err(|_| Error::Serialize(1142)) { + Ok(p) => p, + Err(e) => return e.into_response(), + }; + return payg_stream(endpoint.endpoint.clone(), v, None).await; + } + let (data, signature, limit) = match project .check_query( body, endpoint.endpoint.clone(), @@ -326,7 +358,11 @@ async fn ep_query_handler( no_sig, None, ) - .await?; + .await + { + Ok(p) => p, + Err(e) => return e.into_response(), + }; let (body, mut headers) = match res_fmt.to_str() { Ok("inline") => ( @@ -354,7 +390,7 @@ async fn ep_query_handler( headers.push(("X-RateLimit-Remaining-Second", (t - u).to_string().leak())); } - Ok(build_response(body, headers)) + build_response(body, headers).into_response() } async fn ws_query( @@ -453,7 +489,7 @@ async fn ep_payg_handler( Ok(p) => p, Err(e) => return e.into_response(), }; - return payg_stream(endpoint.endpoint.clone(), v, state, false).await; + return payg_stream(endpoint.endpoint.clone(), v, Some(state)).await; } let (data, signature, state_data, limit) = match block.to_str() { @@ -700,10 +736,9 @@ async fn ws_handler( async fn payg_stream( endpoint: String, v: Value, - state: MultipleQueryState, - is_test: bool, + state: Option, ) -> AxumResponse { - let mut res = StreamBodyAs::text(api_stream(endpoint, v, state, is_test)).into_response(); + let mut res = StreamBodyAs::text(api_stream(endpoint, v, state)).into_response(); res.headers_mut() .insert("Content-Type", "text/event-stream".parse().unwrap()); res.headers_mut()