diff --git a/Cargo.lock b/Cargo.lock index 144ce10c07..1e08fa523b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2011,6 +2011,7 @@ dependencies = [ "common-daft-config", "common-error", "common-file-formats", + "common-runtime", "daft-core", "daft-dsl", "daft-local-execution", diff --git a/src/common/runtime/src/lib.rs b/src/common/runtime/src/lib.rs index 2c8fc6acdd..df222fcfe9 100644 --- a/src/common/runtime/src/lib.rs +++ b/src/common/runtime/src/lib.rs @@ -69,13 +69,16 @@ impl Future for RuntimeTask { } pub struct Runtime { - runtime: tokio::runtime::Runtime, + pub runtime: Arc, pool_type: PoolType, } impl Runtime { pub(crate) fn new(runtime: tokio::runtime::Runtime, pool_type: PoolType) -> RuntimeRef { - Arc::new(Self { runtime, pool_type }) + Arc::new(Self { + runtime: Arc::new(runtime), + pool_type, + }) } async fn execute_task(future: F, pool_type: PoolType) -> DaftResult diff --git a/src/daft-connect/Cargo.toml b/src/daft-connect/Cargo.toml index bb846b9c46..d2b48e24ec 100644 --- a/src/daft-connect/Cargo.toml +++ b/src/daft-connect/Cargo.toml @@ -25,6 +25,7 @@ tokio = {version = "1.40.0", features = ["full"]} tonic = "0.12.3" tracing = {workspace = true} uuid = {version = "1.10.0", features = ["v4"]} +common-runtime.workspace = true [features] default = ["python"] diff --git a/src/daft-connect/src/execute.rs b/src/daft-connect/src/execute.rs index efca8f85b8..618570c68c 100644 --- a/src/daft-connect/src/execute.rs +++ b/src/daft-connect/src/execute.rs @@ -1,7 +1,7 @@ use std::{future::ready, sync::Arc}; use common_daft_config::DaftExecutionConfig; -use common_error::{DaftError, DaftResult}; +use common_error::DaftResult; use common_file_formats::FileFormat; use daft_dsl::LiteralValue; use daft_local_execution::NativeExecutor; @@ -62,16 +62,15 @@ impl Session { Runner::Native => { let this = self.clone(); - let result_stream = tokio::task::spawn_blocking(move || { - let plan = lp.optimize()?; - let cfg = Arc::new(DaftExecutionConfig::default()); - let native_executor = NativeExecutor::default(); - - let results = native_executor.run(&plan, &*this.psets, cfg, None)?; - let it = results.into_iter(); - Ok::<_, DaftError>(it.collect_vec()) - }) - .await??; + + let plan = lp.optimize()?; + let cfg = Arc::new(DaftExecutionConfig::default()); + let rt = common_runtime::get_compute_runtime(); + let native_executor = NativeExecutor::default().with_runtime(rt.runtime.clone()); + + let results = native_executor.run(&plan, &*this.psets, cfg, None)?; + let it = results.into_iter(); + let result_stream = it.collect_vec(); Ok(Box::pin(stream::iter(result_stream))) } @@ -91,8 +90,9 @@ impl Session { let (tx, rx) = tokio::sync::mpsc::channel::>(1); let this = self.clone(); + let rt = common_runtime::get_compute_runtime(); - tokio::spawn(async move { + rt.spawn(async move { let execution_fut = async { let translator = translation::SparkAnalyzer::new(&this); match command.rel_type { diff --git a/src/daft-connect/src/lib.rs b/src/daft-connect/src/lib.rs index 4980ce0d2f..eedc2a0559 100644 --- a/src/daft-connect/src/lib.rs +++ b/src/daft-connect/src/lib.rs @@ -81,10 +81,10 @@ pub fn start(addr: &str) -> eyre::Result { shutdown_signal: Some(shutdown_signal), port, }; + let runtime = common_runtime::get_io_runtime(true); std::thread::spawn(move || { - let runtime = tokio::runtime::Runtime::new().unwrap(); - let result = runtime.block_on(async { + let result = runtime.runtime.block_on(async { let incoming = { let listener = tokio::net::TcpListener::from_std(listener) .wrap_err("Failed to create TcpListener from std::net::TcpListener")?; diff --git a/src/daft-local-execution/src/lib.rs b/src/daft-local-execution/src/lib.rs index 4a01a63eb6..ef6cdbe93b 100644 --- a/src/daft-local-execution/src/lib.rs +++ b/src/daft-local-execution/src/lib.rs @@ -26,7 +26,7 @@ use common_runtime::{RuntimeRef, RuntimeTask}; use lazy_static::lazy_static; use progress_bar::{OperatorProgressBar, ProgressBarColor, ProgressBarManager}; use resource_manager::MemoryManager; -pub use run::{run_local, ExecutionEngineResult, NativeExecutor}; +pub use run::{ExecutionEngineResult, NativeExecutor}; use runtime_stats::{RuntimeStatsContext, TimedFuture}; use snafu::{futures::TryFutureExt, ResultExt, Snafu}; use tracing::Instrument; diff --git a/src/daft-local-execution/src/run.rs b/src/daft-local-execution/src/run.rs index 94f54eccee..4eca2e2011 100644 --- a/src/daft-local-execution/src/run.rs +++ b/src/daft-local-execution/src/run.rs @@ -9,7 +9,7 @@ use std::{ use common_daft_config::DaftExecutionConfig; use common_error::DaftResult; use common_tracing::refresh_chrome_trace; -use daft_local_plan::{translate, LocalPhysicalPlan}; +use daft_local_plan::translate; use daft_logical_plan::LogicalPlanBuilder; use daft_micropartition::{ partitioning::{InMemoryPartitionSetCache, MicroPartitionSet, PartitionSetCache}, @@ -62,6 +62,12 @@ pub struct PyNativeExecutor { executor: NativeExecutor, } +impl Default for PyNativeExecutor { + fn default() -> Self { + Self::new() + } +} + #[cfg(feature = "python")] #[pymethods] impl PyNativeExecutor { @@ -125,6 +131,7 @@ pub struct NativeExecutor { cancel: CancellationToken, runtime: Option>, pb_manager: Option>, + enable_explain_analyze: bool, } impl Default for NativeExecutor { @@ -132,17 +139,19 @@ impl Default for NativeExecutor { Self { cancel: CancellationToken::new(), runtime: None, - pb_manager: None, + pb_manager: should_enable_progress_bar().then(make_progress_bar_manager), + enable_explain_analyze: should_enable_explain_analyze(), } } } + impl NativeExecutor { pub fn new() -> Self { Self::default() } - pub fn with_runtime(mut self, runtime: tokio::runtime::Runtime) -> Self { - self.runtime = Some(Arc::new(runtime)); + pub fn with_runtime(mut self, runtime: Arc) -> Self { + self.runtime = Some(runtime); self } @@ -151,6 +160,11 @@ impl NativeExecutor { self } + pub fn enable_explain_analyze(mut self, b: bool) -> Self { + self.enable_explain_analyze = b; + self + } + pub fn run( &self, logical_plan_builder: &LogicalPlanBuilder, @@ -164,10 +178,12 @@ impl NativeExecutor { let cancel = self.cancel.clone(); let pipeline = physical_plan_to_pipeline(&physical_plan, psets, &cfg)?; let (tx, rx) = create_channel(results_buffer_size.unwrap_or(1)); + let rt = self.runtime.clone(); + let pb_manager = self.pb_manager.clone(); + let enable_explain_analyze = self.enable_explain_analyze; let handle = std::thread::spawn(move || { - let pb_manager = should_enable_progress_bar().then(make_progress_bar_manager); let runtime = rt.unwrap_or_else(|| { Arc::new( tokio::runtime::Builder::new_current_thread() @@ -204,7 +220,7 @@ impl NativeExecutor { _ => {} } } - if should_enable_explain_analyze() { + if enable_explain_analyze { let curr_ms = SystemTime::now() .duration_since(UNIX_EPOCH) .expect("Time went backwards") @@ -357,82 +373,3 @@ impl IntoIterator for ExecutionEngineResult { } } } - -pub fn run_local( - physical_plan: &LocalPhysicalPlan, - psets: &(impl PartitionSetCache> + ?Sized), - cfg: Arc, - results_buffer_size: Option, - cancel: CancellationToken, -) -> DaftResult { - refresh_chrome_trace(); - let pipeline = physical_plan_to_pipeline(physical_plan, psets, &cfg)?; - let (tx, rx) = create_channel(results_buffer_size.unwrap_or(1)); - let handle = std::thread::spawn(move || { - let pb_manager = should_enable_progress_bar().then(make_progress_bar_manager); - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .expect("Failed to create tokio runtime"); - let execution_task = async { - let memory_manager = get_or_init_memory_manager(); - let mut runtime_handle = ExecutionRuntimeContext::new( - cfg.default_morsel_size, - memory_manager.clone(), - pb_manager, - ); - let receiver = pipeline.start(true, &mut runtime_handle)?; - - while let Some(val) = receiver.recv().await { - if tx.send(val).await.is_err() { - break; - } - } - - while let Some(result) = runtime_handle.join_next().await { - match result { - Ok(Err(e)) => { - runtime_handle.shutdown().await; - return DaftResult::Err(e.into()); - } - Err(e) => { - runtime_handle.shutdown().await; - return DaftResult::Err(Error::JoinError { source: e }.into()); - } - _ => {} - } - } - if should_enable_explain_analyze() { - let curr_ms = SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("Time went backwards") - .as_millis(); - let file_name = format!("explain-analyze-{curr_ms}-mermaid.md"); - let mut file = File::create(file_name)?; - writeln!(file, "```mermaid\n{}\n```", viz_pipeline(pipeline.as_ref()))?; - } - Ok(()) - }; - - let local_set = tokio::task::LocalSet::new(); - local_set.block_on(&runtime, async { - tokio::select! { - biased; - () = cancel.cancelled() => { - log::info!("Execution engine cancelled"); - Ok(()) - } - _ = tokio::signal::ctrl_c() => { - log::info!("Received Ctrl-C, shutting down execution engine"); - Ok(()) - } - result = execution_task => result, - } - }) - }); - - Ok(ExecutionEngineResult { - handle, - receiver: rx, - }) -}